|
@@ -1,5 +1,6 @@
|
|
|
import logging
|
|
|
import os
|
|
|
+import uuid
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
import requests
|
|
@@ -91,7 +92,7 @@ def query_doc_with_hybrid_search(
|
|
|
k: int,
|
|
|
reranking_function,
|
|
|
r: float,
|
|
|
-):
|
|
|
+) -> dict:
|
|
|
try:
|
|
|
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
|
|
|
|
@@ -134,7 +135,7 @@ def query_doc_with_hybrid_search(
|
|
|
raise e
|
|
|
|
|
|
|
|
|
-def merge_and_sort_query_results(query_results, k, reverse=False):
|
|
|
+def merge_and_sort_query_results(query_results: list[dict], k: int, reverse: bool = False) -> list[dict]:
|
|
|
# Initialize lists to store combined data
|
|
|
combined_distances = []
|
|
|
combined_documents = []
|
|
@@ -180,7 +181,7 @@ def query_collection(
|
|
|
query: str,
|
|
|
embedding_function,
|
|
|
k: int,
|
|
|
-):
|
|
|
+) -> dict:
|
|
|
results = []
|
|
|
for collection_name in collection_names:
|
|
|
if collection_name:
|
|
@@ -192,8 +193,8 @@ def query_collection(
|
|
|
embedding_function=embedding_function,
|
|
|
)
|
|
|
results.append(result)
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(f"Error when querying the collection: {e}")
|
|
|
else:
|
|
|
pass
|
|
|
|
|
@@ -207,8 +208,9 @@ def query_collection_with_hybrid_search(
|
|
|
k: int,
|
|
|
reranking_function,
|
|
|
r: float,
|
|
|
-):
|
|
|
+) -> dict:
|
|
|
results = []
|
|
|
+ failed = 0
|
|
|
for collection_name in collection_names:
|
|
|
try:
|
|
|
result = query_doc_with_hybrid_search(
|
|
@@ -220,14 +222,39 @@ def query_collection_with_hybrid_search(
|
|
|
r=r,
|
|
|
)
|
|
|
results.append(result)
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(
|
|
|
+ "Error when querying the collection with "
|
|
|
+ f"hybrid_search: {e}"
|
|
|
+ )
|
|
|
+ failed += 1
|
|
|
+ if failed == len(collection_names):
|
|
|
+ raise Exception("Hybrid search failed for all collections. Using "
|
|
|
+ "Non hybrid search as fallback.")
|
|
|
return merge_and_sort_query_results(results, k=k, reverse=True)
|
|
|
|
|
|
|
|
|
def rag_template(template: str, context: str, query: str) -> str:
    """Fill a RAG prompt template with the retrieved context and user query.

    The template must contain the ``[context]`` token exactly once; it is
    replaced by ``context``. Every ``[query]`` token in the template is
    replaced by ``query``. A literal ``[query]`` that appears inside
    ``context`` is left untouched.

    Args:
        template: Prompt template containing ``[context]`` (exactly once)
            and, optionally, ``[query]``.
        context: Text substituted for ``[context]``.
        query: Text substituted for ``[query]``.

    Returns:
        The template with both tokens substituted.

    Raises:
        AssertionError: If ``[context]`` does not occur exactly once in
            ``template``.
    """
    count = template.count("[context]")
    # Raised explicitly (rather than via ``assert``) so the validation is not
    # stripped under ``python -O``; the exception type is kept unchanged for
    # existing callers. A count of exactly one also implies the token is
    # present, so no separate membership check is needed.
    if count != 1:
        raise AssertionError(
            f"RAG template contains an unexpected number of '[context]' : {count}"
        )

    if "<context>" in context and "</context>" in context:
        log.debug(
            "WARNING: Potential prompt injection attack: the RAG "
            "context contains '<context>' and '</context>'. This might be "
            "nothing, or the user might be trying to hack something."
        )

    if "[query]" in context:
        # The context itself contains the literal token "[query]". Park the
        # template's own "[query]" behind a unique placeholder first, so that
        # only the template's token (not the one inside the context) ends up
        # replaced by the user's query.
        query_placeholder = str(uuid.uuid4())
        template = template.replace("[query]", query_placeholder)
        template = template.replace("[context]", context)
        template = template.replace(query_placeholder, query)
    else:
        template = template.replace("[context]", context)
        template = template.replace("[query]", query)
    return template
|
|
|
|
|
|
|
|
@@ -304,19 +331,25 @@ def get_rag_context(
|
|
|
continue
|
|
|
|
|
|
try:
|
|
|
+ context = None
|
|
|
if file["type"] == "text":
|
|
|
context = file["content"]
|
|
|
else:
|
|
|
if hybrid_search:
|
|
|
- context = query_collection_with_hybrid_search(
|
|
|
- collection_names=collection_names,
|
|
|
- query=query,
|
|
|
- embedding_function=embedding_function,
|
|
|
- k=k,
|
|
|
- reranking_function=reranking_function,
|
|
|
- r=r,
|
|
|
- )
|
|
|
- else:
|
|
|
+ try:
|
|
|
+ context = query_collection_with_hybrid_search(
|
|
|
+ collection_names=collection_names,
|
|
|
+ query=query,
|
|
|
+ embedding_function=embedding_function,
|
|
|
+ k=k,
|
|
|
+ reranking_function=reranking_function,
|
|
|
+ r=r,
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ log.debug("Error when using hybrid search, using"
|
|
|
+ " non hybrid search as fallback.")
|
|
|
+
|
|
|
+ if (not hybrid_search) or (context is None):
|
|
|
context = query_collection(
|
|
|
collection_names=collection_names,
|
|
|
query=query,
|
|
@@ -325,7 +358,6 @@ def get_rag_context(
|
|
|
)
|
|
|
except Exception as e:
|
|
|
log.exception(e)
|
|
|
- context = None
|
|
|
|
|
|
if context:
|
|
|
relevant_contexts.append({**context, "source": file})
|