Browse Source

revert: original rag pipeline

Timothy J. Baek 1 year ago
parent
commit
984dbf13ab
1 changed files with 51 additions and 34 deletions
  1. 51 34
      backend/apps/rag/utils.py

+ 51 - 34
backend/apps/rag/utils.py

@@ -18,6 +18,9 @@ from langchain.retrievers import (
     EnsembleRetriever,
 )
 
+from sentence_transformers import CrossEncoder
+
+from typing import Optional
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
 
@@ -28,50 +31,64 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 def query_embeddings_doc(
     collection_name: str,
     query: str,
-    k: int,
-    r: float,
     embeddings_function,
-    reranking_function,
+    k: int,
+    reranking_function: Optional[CrossEncoder] = None,
+    r: Optional[float] = None,
 ):
     try:
-        # if you use docker use the model from the environment variable
-        collection = CHROMA_CLIENT.get_collection(name=collection_name)
 
-        documents = collection.get()  # get all documents
-        bm25_retriever = BM25Retriever.from_texts(
-            texts=documents.get("documents"),
-            metadatas=documents.get("metadatas"),
-        )
-        bm25_retriever.k = k
+        if reranking_function:
+            # if you use docker use the model from the environment variable
+            collection = CHROMA_CLIENT.get_collection(name=collection_name)
 
-        chroma_retriever = ChromaRetriever(
-            collection=collection,
-            embeddings_function=embeddings_function,
-            top_n=k,
-        )
+            documents = collection.get()  # get all documents
+            bm25_retriever = BM25Retriever.from_texts(
+                texts=documents.get("documents"),
+                metadatas=documents.get("metadatas"),
+            )
+            bm25_retriever.k = k
 
-        ensemble_retriever = EnsembleRetriever(
-            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
-        )
+            chroma_retriever = ChromaRetriever(
+                collection=collection,
+                embeddings_function=embeddings_function,
+                top_n=k,
+            )
 
-        compressor = RerankCompressor(
-            embeddings_function=embeddings_function,
-            reranking_function=reranking_function,
-            r_score=r,
-            top_n=k,
-        )
+            ensemble_retriever = EnsembleRetriever(
+                retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
+            )
 
-        compression_retriever = ContextualCompressionRetriever(
-            base_compressor=compressor, base_retriever=ensemble_retriever
-        )
+            compressor = RerankCompressor(
+                embeddings_function=embeddings_function,
+                reranking_function=reranking_function,
+                r_score=r,
+                top_n=k,
+            )
 
-        result = compression_retriever.invoke(query)
-        result = {
-            "distances": [[d.metadata.get("score") for d in result]],
-            "documents": [[d.page_content for d in result]],
-            "metadatas": [[d.metadata for d in result]],
-        }
+            compression_retriever = ContextualCompressionRetriever(
+                base_compressor=compressor, base_retriever=ensemble_retriever
+            )
+
+            result = compression_retriever.invoke(query)
+            result = {
+                "distances": [[d.metadata.get("score") for d in result]],
+                "documents": [[d.page_content for d in result]],
+                "metadatas": [[d.metadata for d in result]],
+            }
+        else:
+            # if you use docker use the model from the environment variable
+            query_embeddings = embeddings_function(query)
+
+            log.info(f"query_embeddings_doc {query_embeddings}")
+            collection = CHROMA_CLIENT.get_collection(name=collection_name)
+
+            result = collection.query(
+                query_embeddings=[query_embeddings],
+                n_results=k,
+            )
 
+            log.info(f"query_embeddings_doc:result {result}")
         return result
     except Exception as e:
         raise e