|
@@ -18,6 +18,9 @@ from langchain.retrievers import (
|
|
|
EnsembleRetriever,
|
|
|
)
|
|
|
|
|
|
+from sentence_transformers import CrossEncoder
|
|
|
+
|
|
|
+from typing import Optional
|
|
|
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
|
|
|
|
|
|
|
@@ -28,50 +31,64 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
def query_embeddings_doc(
    collection_name: str,
    query: str,
    embeddings_function,
    k: int,
    reranking_function: Optional[CrossEncoder] = None,
    r: Optional[float] = None,
):
    """Query one Chroma collection for the documents most relevant to ``query``.

    The retrieval strategy is selected by ``reranking_function``:

    * truthy -- hybrid search: a BM25 retriever built from every document in
      the collection and a Chroma vector retriever are combined with equal
      weights, and the merged candidates go through ``RerankCompressor``.
    * falsy (the default) -- plain vector search: the query is embedded once
      and handed to ``collection.query``.

    Args:
        collection_name: Name of the collection fetched from ``CHROMA_CLIENT``.
        query: The query text.
        embeddings_function: Called as ``embeddings_function(query)`` on the
            vector-only path; forwarded to ``ChromaRetriever`` and
            ``RerankCompressor`` on the hybrid path.
        k: Result count, used as the BM25 ``k``, the retriever/compressor
            ``top_n`` and the Chroma ``n_results``.
        reranking_function: When truthy, enables the hybrid path and is
            forwarded to ``RerankCompressor``.
        r: Forwarded to ``RerankCompressor`` as ``r_score``; not used on the
            vector-only path. Presumably a minimum relevance score -- confirm
            against ``RerankCompressor``.

    Returns:
        On the hybrid path, a dict with the keys ``"distances"`` (each
        document's ``metadata["score"]``), ``"documents"`` and
        ``"metadatas"``, every value being a one-element list that wraps the
        per-document list. On the vector-only path, the value returned by
        ``collection.query`` as-is.

    Raises:
        Nothing is caught here: any exception raised by the Chroma client,
        the retrievers or ``embeddings_function`` propagates to the caller.
    """
    if reranking_function:
        collection = CHROMA_CLIENT.get_collection(name=collection_name)

        # BM25 has no persistent index here, so it is rebuilt from the whole
        # collection on every call.
        documents = collection.get()
        bm25_retriever = BM25Retriever.from_texts(
            texts=documents.get("documents"),
            metadatas=documents.get("metadatas"),
        )
        bm25_retriever.k = k

        chroma_retriever = ChromaRetriever(
            collection=collection,
            embeddings_function=embeddings_function,
            top_n=k,
        )

        # Keyword and vector retrieval contribute equally to the candidates.
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
        )

        compressor = RerankCompressor(
            embeddings_function=embeddings_function,
            reranking_function=reranking_function,
            r_score=r,
            top_n=k,
        )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        reranked = compression_retriever.invoke(query)
        # Use the same keys and one-level nesting as the vector-only result
        # so callers can treat both paths alike.
        result = {
            "distances": [[d.metadata.get("score") for d in reranked]],
            "documents": [[d.page_content for d in reranked]],
            "metadatas": [[d.metadata for d in reranked]],
        }
    else:
        query_embeddings = embeddings_function(query)

        log.info(f"query_embeddings_doc {query_embeddings}")
        collection = CHROMA_CLIENT.get_collection(name=collection_name)

        result = collection.query(
            query_embeddings=[query_embeddings],
            n_results=k,
        )

    log.info(f"query_embeddings_doc:result {result}")
    return result
|