Browse Source

Merge pull request #5378 from thiswillbeyourgithub/fix_RAG_and_web

fix: RAG and Web Search + RAG enhancements
Timothy Jaeryang Baek 7 months ago
parent
commit
7dc4cb30b2
3 changed files with 73 additions and 30 deletions
  1. 52 20
      backend/open_webui/apps/rag/utils.py
  2. 16 10
      backend/open_webui/config.py
  3. 5 0
      backend/open_webui/main.py

+ 52 - 20
backend/open_webui/apps/rag/utils.py

@@ -1,5 +1,6 @@
 import logging
 import logging
 import os
 import os
+import uuid
 from typing import Optional, Union
 from typing import Optional, Union
 
 
 import requests
 import requests
@@ -91,7 +92,7 @@ def query_doc_with_hybrid_search(
     k: int,
     k: int,
     reranking_function,
     reranking_function,
     r: float,
     r: float,
-):
+) -> dict:
     try:
     try:
         result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
         result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
 
 
@@ -134,7 +135,7 @@ def query_doc_with_hybrid_search(
         raise e
         raise e
 
 
 
 
-def merge_and_sort_query_results(query_results, k, reverse=False):
+def merge_and_sort_query_results(query_results: list[dict], k: int, reverse: bool = False) -> list[dict]:
     # Initialize lists to store combined data
     # Initialize lists to store combined data
     combined_distances = []
     combined_distances = []
     combined_documents = []
     combined_documents = []
@@ -180,7 +181,7 @@ def query_collection(
     query: str,
     query: str,
     embedding_function,
     embedding_function,
     k: int,
     k: int,
-):
+) -> dict:
     results = []
     results = []
     for collection_name in collection_names:
     for collection_name in collection_names:
         if collection_name:
         if collection_name:
@@ -192,8 +193,8 @@ def query_collection(
                     embedding_function=embedding_function,
                     embedding_function=embedding_function,
                 )
                 )
                 results.append(result)
                 results.append(result)
-            except Exception:
-                pass
+            except Exception as e:
+                log.exception(f"Error when querying the collection: {e}")
         else:
         else:
             pass
             pass
 
 
@@ -207,8 +208,9 @@ def query_collection_with_hybrid_search(
     k: int,
     k: int,
     reranking_function,
     reranking_function,
     r: float,
     r: float,
-):
+) -> dict:
     results = []
     results = []
+    failed = 0
     for collection_name in collection_names:
     for collection_name in collection_names:
         try:
         try:
             result = query_doc_with_hybrid_search(
             result = query_doc_with_hybrid_search(
@@ -220,14 +222,39 @@ def query_collection_with_hybrid_search(
                 r=r,
                 r=r,
             )
             )
             results.append(result)
             results.append(result)
-        except Exception:
-            pass
+        except Exception as e:
+            log.exception(
+                "Error when querying the collection with "
+                f"hybrid_search: {e}"
+            )
+            failed += 1
+    if failed == len(collection_names):
+        raise Exception("Hybrid search failed for all collections. Using "
+                        "Non hybrid search as fallback.")
     return merge_and_sort_query_results(results, k=k, reverse=True)
     return merge_and_sort_query_results(results, k=k, reverse=True)
 
 
 
 
 def rag_template(template: str, context: str, query: str):
 def rag_template(template: str, context: str, query: str):
-    template = template.replace("[context]", context)
-    template = template.replace("[query]", query)
+    count = template.count("[context]")
+    assert count == 1, (
+        f"RAG template contains an unexpected number of '[context]' : {count}"
+    )
+    assert "[context]" in template, "RAG template does not contain '[context]'"
+    if "<context>" in context and "</context>" in context:
+        log.debug(
+            "WARNING: Potential prompt injection attack: the RAG "
+            "context contains '<context>' and '</context>'. This might be "
+            "nothing, or the user might be trying to hack something."
+        )
+
+    if "[query]" in context:
+        query_placeholder = str(uuid.uuid4())
+        template = template.replace("[QUERY]", query_placeholder)
+        template = template.replace("[context]", context)
+        template = template.replace(query_placeholder, query)
+    else:
+        template = template.replace("[context]", context)
+        template = template.replace("[query]", query)
     return template
     return template
 
 
 
 
@@ -304,19 +331,25 @@ def get_rag_context(
             continue
             continue
 
 
         try:
         try:
+            context = None
             if file["type"] == "text":
             if file["type"] == "text":
                 context = file["content"]
                 context = file["content"]
             else:
             else:
                 if hybrid_search:
                 if hybrid_search:
-                    context = query_collection_with_hybrid_search(
-                        collection_names=collection_names,
-                        query=query,
-                        embedding_function=embedding_function,
-                        k=k,
-                        reranking_function=reranking_function,
-                        r=r,
-                    )
-                else:
+                    try:
+                        context = query_collection_with_hybrid_search(
+                            collection_names=collection_names,
+                            query=query,
+                            embedding_function=embedding_function,
+                            k=k,
+                            reranking_function=reranking_function,
+                            r=r,
+                        )
+                    except Exception as e:
+                        log.debug("Error when using hybrid search, using"
+                                    " non hybrid search as fallback.")
+
+                if (not hybrid_search) or (context is None):
                     context = query_collection(
                     context = query_collection(
                         collection_names=collection_names,
                         collection_names=collection_names,
                         query=query,
                         query=query,
@@ -325,7 +358,6 @@ def get_rag_context(
                     )
                     )
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
-            context = None
 
 
         if context:
         if context:
             relevant_contexts.append({**context, "source": file})
             relevant_contexts.append({**context, "source": file})

+ 16 - 10
backend/open_webui/config.py

@@ -1030,19 +1030,25 @@ CHUNK_OVERLAP = PersistentConfig(
     int(os.environ.get("CHUNK_OVERLAP", "100")),
     int(os.environ.get("CHUNK_OVERLAP", "100")),
 )
 )
 
 
-DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
+DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
+
 <context>
 <context>
-    [context]
+[context]
 </context>
 </context>
 
 
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-
-Given the context information, answer the query.
-Query: [query]"""
+<rules>
+- If you don't know, just say so.
+- If you are not sure, ask for clarification.
+- Answer in the same language as the user query.
+- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
+- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
+- Answer directly and without using xml tags.
+</rules>
+
+<user_query>
+[query]
+</user_query>
+"""
 
 
 RAG_TEMPLATE = PersistentConfig(
 RAG_TEMPLATE = PersistentConfig(
     "RAG_TEMPLATE",
     "RAG_TEMPLATE",

+ 5 - 0
backend/open_webui/main.py

@@ -588,6 +588,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             prompt = get_last_user_message(body["messages"])
             prompt = get_last_user_message(body["messages"])
             if prompt is None:
             if prompt is None:
                 raise Exception("No user message found")
                 raise Exception("No user message found")
+            if rag_app.state.config.RELEVANCE_THRESHOLD == 0:
+                assert context_string.strip(), (
+                    "With a 0 relevancy threshold for RAG, the context cannot "
+                    "be empty"
+                )
             # Workaround for Ollama 2.0+ system prompt issue
             # Workaround for Ollama 2.0+ system prompt issue
             # TODO: replace with add_or_update_system_message
             # TODO: replace with add_or_update_system_message
             if model["owned_by"] == "ollama":
             if model["owned_by"] == "ollama":