|
@@ -26,61 +26,72 @@ log = logging.getLogger(__name__)
|
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
|
|
|
|
|
|
|
-def query_embeddings_doc(
|
|
|
|
|
|
+def query_doc(
|
|
collection_name: str,
|
|
collection_name: str,
|
|
query: str,
|
|
query: str,
|
|
- embeddings_function,
|
|
|
|
- reranking_function,
|
|
|
|
|
|
+ embedding_function,
|
|
k: int,
|
|
k: int,
|
|
- r: int,
|
|
|
|
- hybrid_search: bool,
|
|
|
|
):
|
|
):
|
|
try:
|
|
try:
|
|
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
|
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
|
|
|
+ query_embeddings = embedding_function(query)
|
|
|
|
+ result = collection.query(
|
|
|
|
+ query_embeddings=[query_embeddings],
|
|
|
|
+ n_results=k,
|
|
|
|
+ )
|
|
|
|
|
|
- if hybrid_search:
|
|
|
|
- documents = collection.get() # get all documents
|
|
|
|
- bm25_retriever = BM25Retriever.from_texts(
|
|
|
|
- texts=documents.get("documents"),
|
|
|
|
- metadatas=documents.get("metadatas"),
|
|
|
|
- )
|
|
|
|
- bm25_retriever.k = k
|
|
|
|
|
|
+ log.info(f"query_doc:result {result}")
|
|
|
|
+ return result
|
|
|
|
+ except Exception as e:
|
|
|
|
+ raise e
|
|
|
|
|
|
- chroma_retriever = ChromaRetriever(
|
|
|
|
- collection=collection,
|
|
|
|
- embeddings_function=embeddings_function,
|
|
|
|
- top_n=k,
|
|
|
|
- )
|
|
|
|
|
|
|
|
- ensemble_retriever = EnsembleRetriever(
|
|
|
|
- retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
|
|
|
|
- )
|
|
|
|
|
|
+def query_doc_with_hybrid_search(
|
|
|
|
+ collection_name: str,
|
|
|
|
+ query: str,
|
|
|
|
+ embedding_function,
|
|
|
|
+ k: int,
|
|
|
|
+ reranking_function,
|
|
|
|
+ r: int,
|
|
|
|
+):
|
|
|
|
+ try:
|
|
|
|
+ collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
|
|
|
+ documents = collection.get() # get all documents
|
|
|
|
|
|
- compressor = RerankCompressor(
|
|
|
|
- embeddings_function=embeddings_function,
|
|
|
|
- reranking_function=reranking_function,
|
|
|
|
- r_score=r,
|
|
|
|
- top_n=k,
|
|
|
|
- )
|
|
|
|
|
|
+ bm25_retriever = BM25Retriever.from_texts(
|
|
|
|
+ texts=documents.get("documents"),
|
|
|
|
+ metadatas=documents.get("metadatas"),
|
|
|
|
+ )
|
|
|
|
+ bm25_retriever.k = k
|
|
|
|
|
|
- compression_retriever = ContextualCompressionRetriever(
|
|
|
|
- base_compressor=compressor, base_retriever=ensemble_retriever
|
|
|
|
- )
|
|
|
|
|
|
+ chroma_retriever = ChromaRetriever(
|
|
|
|
+ collection=collection,
|
|
|
|
+ embedding_function=embedding_function,
|
|
|
|
+ top_n=k,
|
|
|
|
+ )
|
|
|
|
|
|
- result = compression_retriever.invoke(query)
|
|
|
|
- result = {
|
|
|
|
- "distances": [[d.metadata.get("score") for d in result]],
|
|
|
|
- "documents": [[d.page_content for d in result]],
|
|
|
|
- "metadatas": [[d.metadata for d in result]],
|
|
|
|
- }
|
|
|
|
- else:
|
|
|
|
- query_embeddings = embeddings_function(query)
|
|
|
|
- result = collection.query(
|
|
|
|
- query_embeddings=[query_embeddings],
|
|
|
|
- n_results=k,
|
|
|
|
- )
|
|
|
|
|
|
+ ensemble_retriever = EnsembleRetriever(
|
|
|
|
+ retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ compressor = RerankCompressor(
|
|
|
|
+ embedding_function=embedding_function,
|
|
|
|
+ reranking_function=reranking_function,
|
|
|
|
+ r_score=r,
|
|
|
|
+ top_n=k,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ compression_retriever = ContextualCompressionRetriever(
|
|
|
|
+ base_compressor=compressor, base_retriever=ensemble_retriever
|
|
|
|
+ )
|
|
|
|
|
|
- log.info(f"query_embeddings_doc:result {result}")
|
|
|
|
|
|
+ result = compression_retriever.invoke(query)
|
|
|
|
+ result = {
|
|
|
|
+ "distances": [[d.metadata.get("score") for d in result]],
|
|
|
|
+ "documents": [[d.page_content for d in result]],
|
|
|
|
+ "metadatas": [[d.metadata for d in result]],
|
|
|
|
+ }
|
|
|
|
+ log.info(f"query_doc_with_hybrid_search:result {result}")
|
|
return result
|
|
return result
|
|
except Exception as e:
|
|
except Exception as e:
|
|
raise e
|
|
raise e
|
|
@@ -127,35 +138,52 @@ def merge_and_sort_query_results(query_results, k, reverse=False):
|
|
return result
|
|
return result
|
|
|
|
|
|
|
|
|
|
-def query_embeddings_collection(
|
|
|
|
|
|
+def query_collection(
|
|
collection_names: List[str],
|
|
collection_names: List[str],
|
|
query: str,
|
|
query: str,
|
|
|
|
+ embedding_function,
|
|
|
|
+ k: int,
|
|
|
|
+):
|
|
|
|
+ results = []
|
|
|
|
+ for collection_name in collection_names:
|
|
|
|
+ try:
|
|
|
|
+ result = query_doc(
|
|
|
|
+ collection_name=collection_name,
|
|
|
|
+ query=query,
|
|
|
|
+ k=k,
|
|
|
|
+ embedding_function=embedding_function,
|
|
|
|
+ )
|
|
|
|
+ results.append(result)
|
|
|
|
+ except:
|
|
|
|
+ pass
|
|
|
|
+ return merge_and_sort_query_results(results, k=k)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def query_collection_with_hybrid_search(
|
|
|
|
+ collection_names: List[str],
|
|
|
|
+ query: str,
|
|
|
|
+ embedding_function,
|
|
k: int,
|
|
k: int,
|
|
- r: float,
|
|
|
|
- embeddings_function,
|
|
|
|
reranking_function,
|
|
reranking_function,
|
|
- hybrid_search: bool,
|
|
|
|
|
|
+ r: float,
|
|
):
|
|
):
|
|
|
|
|
|
results = []
|
|
results = []
|
|
-
|
|
|
|
for collection_name in collection_names:
|
|
for collection_name in collection_names:
|
|
try:
|
|
try:
|
|
- result = query_embeddings_doc(
|
|
|
|
|
|
+ result = query_doc_with_hybrid_search(
|
|
collection_name=collection_name,
|
|
collection_name=collection_name,
|
|
query=query,
|
|
query=query,
|
|
|
|
+ embedding_function=embedding_function,
|
|
k=k,
|
|
k=k,
|
|
- r=r,
|
|
|
|
- embeddings_function=embeddings_function,
|
|
|
|
reranking_function=reranking_function,
|
|
reranking_function=reranking_function,
|
|
- hybrid_search=hybrid_search,
|
|
|
|
|
|
+ r=r,
|
|
)
|
|
)
|
|
results.append(result)
|
|
results.append(result)
|
|
except:
|
|
except:
|
|
pass
|
|
pass
|
|
|
|
|
|
- reverse = hybrid_search and reranking_function is not None
|
|
|
|
- return merge_and_sort_query_results(results, k=k, reverse=reverse)
|
|
|
|
|
|
+ return merge_and_sort_query_results(results, k=k, reverse=True)
|
|
|
|
|
|
|
|
|
|
def rag_template(template: str, context: str, query: str):
|
|
def rag_template(template: str, context: str, query: str):
|
|
@@ -164,7 +192,7 @@ def rag_template(template: str, context: str, query: str):
|
|
return template
|
|
return template
|
|
|
|
|
|
|
|
|
|
-def get_embeddings_function(
|
|
|
|
|
|
+def get_embedding_function(
|
|
embedding_engine,
|
|
embedding_engine,
|
|
embedding_model,
|
|
embedding_model,
|
|
embedding_function,
|
|
embedding_function,
|
|
@@ -204,19 +232,13 @@ def rag_messages(
|
|
docs,
|
|
docs,
|
|
messages,
|
|
messages,
|
|
template,
|
|
template,
|
|
|
|
+ embedding_function,
|
|
k,
|
|
k,
|
|
|
|
+ reranking_function,
|
|
r,
|
|
r,
|
|
hybrid_search,
|
|
hybrid_search,
|
|
- embedding_engine,
|
|
|
|
- embedding_model,
|
|
|
|
- embedding_function,
|
|
|
|
- reranking_function,
|
|
|
|
- openai_key,
|
|
|
|
- openai_url,
|
|
|
|
):
|
|
):
|
|
- log.debug(
|
|
|
|
- f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
|
|
|
|
- )
|
|
|
|
|
|
+ log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
|
|
|
|
|
|
last_user_message_idx = None
|
|
last_user_message_idx = None
|
|
for i in range(len(messages) - 1, -1, -1):
|
|
for i in range(len(messages) - 1, -1, -1):
|
|
@@ -243,14 +265,6 @@ def rag_messages(
|
|
content_type = None
|
|
content_type = None
|
|
query = ""
|
|
query = ""
|
|
|
|
|
|
- embeddings_function = get_embeddings_function(
|
|
|
|
- embedding_engine,
|
|
|
|
- embedding_model,
|
|
|
|
- embedding_function,
|
|
|
|
- openai_key,
|
|
|
|
- openai_url,
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
extracted_collections = []
|
|
extracted_collections = []
|
|
relevant_contexts = []
|
|
relevant_contexts = []
|
|
|
|
|
|
@@ -271,26 +285,31 @@ def rag_messages(
|
|
try:
|
|
try:
|
|
if doc["type"] == "text":
|
|
if doc["type"] == "text":
|
|
context = doc["content"]
|
|
context = doc["content"]
|
|
- elif doc["type"] == "collection":
|
|
|
|
- context = query_embeddings_collection(
|
|
|
|
- collection_names=doc["collection_names"],
|
|
|
|
- query=query,
|
|
|
|
- k=k,
|
|
|
|
- r=r,
|
|
|
|
- embeddings_function=embeddings_function,
|
|
|
|
- reranking_function=reranking_function,
|
|
|
|
- hybrid_search=hybrid_search,
|
|
|
|
- )
|
|
|
|
else:
|
|
else:
|
|
- context = query_embeddings_doc(
|
|
|
|
- collection_name=doc["collection_name"],
|
|
|
|
- query=query,
|
|
|
|
- k=k,
|
|
|
|
- r=r,
|
|
|
|
- embeddings_function=embeddings_function,
|
|
|
|
- reranking_function=reranking_function,
|
|
|
|
- hybrid_search=hybrid_search,
|
|
|
|
- )
|
|
|
|
|
|
+ if hybrid_search:
|
|
|
|
+ context = query_collection_with_hybrid_search(
|
|
|
|
+ collection_names=(
|
|
|
|
+ doc["collection_names"]
|
|
|
|
+ if doc["type"] == "collection"
|
|
|
|
+ else [doc["collection_name"]]
|
|
|
|
+ ),
|
|
|
|
+ query=query,
|
|
|
|
+ embedding_function=embedding_function,
|
|
|
|
+ k=k,
|
|
|
|
+ reranking_function=reranking_function,
|
|
|
|
+ r=r,
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ context = query_collection(
|
|
|
|
+ collection_names=(
|
|
|
|
+ doc["collection_names"]
|
|
|
|
+ if doc["type"] == "collection"
|
|
|
|
+ else [doc["collection_name"]]
|
|
|
|
+ ),
|
|
|
|
+ query=query,
|
|
|
|
+ embedding_function=embedding_function,
|
|
|
|
+ k=k,
|
|
|
|
+ )
|
|
except Exception as e:
|
|
except Exception as e:
|
|
log.exception(e)
|
|
log.exception(e)
|
|
context = None
|
|
context = None
|
|
@@ -404,7 +423,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
|
|
|
|
|
class ChromaRetriever(BaseRetriever):
|
|
class ChromaRetriever(BaseRetriever):
|
|
collection: Any
|
|
collection: Any
|
|
- embeddings_function: Any
|
|
|
|
|
|
+ embedding_function: Any
|
|
top_n: int
|
|
top_n: int
|
|
|
|
|
|
def _get_relevant_documents(
|
|
def _get_relevant_documents(
|
|
@@ -413,7 +432,7 @@ class ChromaRetriever(BaseRetriever):
|
|
*,
|
|
*,
|
|
run_manager: CallbackManagerForRetrieverRun,
|
|
run_manager: CallbackManagerForRetrieverRun,
|
|
) -> List[Document]:
|
|
) -> List[Document]:
|
|
- query_embeddings = self.embeddings_function(query)
|
|
|
|
|
|
+ query_embeddings = self.embedding_function(query)
|
|
|
|
|
|
results = self.collection.query(
|
|
results = self.collection.query(
|
|
query_embeddings=[query_embeddings],
|
|
query_embeddings=[query_embeddings],
|
|
@@ -445,7 +464,7 @@ from sentence_transformers import util
|
|
|
|
|
|
|
|
|
|
class RerankCompressor(BaseDocumentCompressor):
|
|
class RerankCompressor(BaseDocumentCompressor):
|
|
- embeddings_function: Any
|
|
|
|
|
|
+ embedding_function: Any
|
|
reranking_function: Any
|
|
reranking_function: Any
|
|
r_score: float
|
|
r_score: float
|
|
top_n: int
|
|
top_n: int
|
|
@@ -465,8 +484,8 @@ class RerankCompressor(BaseDocumentCompressor):
|
|
[(query, doc.page_content) for doc in documents]
|
|
[(query, doc.page_content) for doc in documents]
|
|
)
|
|
)
|
|
else:
|
|
else:
|
|
- query_embedding = self.embeddings_function(query)
|
|
|
|
- document_embedding = self.embeddings_function(
|
|
|
|
|
|
+ query_embedding = self.embedding_function(query)
|
|
|
|
+ document_embedding = self.embedding_function(
|
|
[doc.page_content for doc in documents]
|
|
[doc.page_content for doc in documents]
|
|
)
|
|
)
|
|
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
|
scores = util.cos_sim(query_embedding, document_embedding)[0]
|