
refac: rag pipeline

Timothy J. Baek · 1 year ago · commit ce9a5d12e0
3 changed files with 177 additions and 152 deletions:
  1. backend/apps/rag/main.py (+55, -45)
  2. backend/apps/rag/utils.py (+114, -95)
  3. backend/main.py (+8, -12)
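In short, this commit renames `get_embeddings_function` to `get_embedding_function`, splits the monolithic `query_embeddings_doc` / `query_embeddings_collection` into plain variants (`query_doc`, `query_collection`) and hybrid-search variants (`query_doc_with_hybrid_search`, `query_collection_with_hybrid_search`), and caches the embedding callable on `app.state.EMBEDDING_FUNCTION` instead of rebuilding it on every request. A minimal sketch of the resulting dispatch; the `query_one_collection` wrapper and its `app` argument are hypothetical stand-ins for the handler code below:

# Sketch: how callers choose between the two query paths after this commit.
from apps.rag.utils import query_doc, query_doc_with_hybrid_search

def query_one_collection(app, collection_name: str, query_text: str, k: int, r: float):
    if app.state.ENABLE_RAG_HYBRID_SEARCH:
        # BM25 + Chroma ensemble, reranked and thresholded at r
        return query_doc_with_hybrid_search(
            collection_name=collection_name,
            query=query_text,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            reranking_function=app.state.sentence_transformer_rf,
            k=k,
            r=r,
        )
    # plain dense retrieval against Chroma
    return query_doc(
        collection_name=collection_name,
        query=query_text,
        embedding_function=app.state.EMBEDDING_FUNCTION,
        k=k,
    )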

backend/apps/rag/main.py (+55, -45)

@@ -47,9 +47,11 @@ from apps.web.models.documents import (
 
 from apps.rag.utils import (
     get_model_path,
-    query_embeddings_doc,
-    get_embeddings_function,
-    query_embeddings_collection,
+    get_embedding_function,
+    query_doc,
+    query_doc_with_hybrid_search,
+    query_collection,
+    query_collection_with_hybrid_search,
 )
 
 from utils.misc import (
@@ -147,6 +149,15 @@ update_reranking_model(
     RAG_RERANKING_MODEL_AUTO_UPDATE,
 )
 
+
+app.state.EMBEDDING_FUNCTION = get_embedding_function(
+    app.state.RAG_EMBEDDING_ENGINE,
+    app.state.RAG_EMBEDDING_MODEL,
+    app.state.sentence_transformer_ef,
+    app.state.OPENAI_API_KEY,
+    app.state.OPENAI_API_BASE_URL,
+)
+
 origins = ["*"]
 
 
@@ -227,6 +238,14 @@ async def update_embedding_config(
 
         update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
 
+        app.state.EMBEDDING_FUNCTION = get_embedding_function(
+            app.state.RAG_EMBEDDING_ENGINE,
+            app.state.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.OPENAI_API_KEY,
+            app.state.OPENAI_API_BASE_URL,
+        )
+
         return {
             "status": True,
             "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
@@ -367,27 +386,22 @@ def query_doc_handler(
     user=Depends(get_current_user),
 ):
     try:
-        embeddings_function = get_embeddings_function(
-            app.state.RAG_EMBEDDING_ENGINE,
-            app.state.RAG_EMBEDDING_MODEL,
-            app.state.sentence_transformer_ef,
-            app.state.OPENAI_API_KEY,
-            app.state.OPENAI_API_BASE_URL,
-        )
-
-        return query_embeddings_doc(
-            collection_name=form_data.collection_name,
-            query=form_data.query,
-            k=form_data.k if form_data.k else app.state.TOP_K,
-            r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
-            embeddings_function=embeddings_function,
-            reranking_function=app.state.sentence_transformer_rf,
-            hybrid_search=(
-                form_data.hybrid
-                if form_data.hybrid
-                else app.state.ENABLE_RAG_HYBRID_SEARCH
-            ),
-        )
+        if app.state.ENABLE_RAG_HYBRID_SEARCH:
+            return query_doc_with_hybrid_search(
+                collection_name=form_data.collection_name,
+                query=form_data.query,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
+                reranking_function=app.state.sentence_transformer_rf,
+                k=form_data.k if form_data.k else app.state.TOP_K,
+                r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
+            )
+        else:
+            return query_doc(
+                collection_name=form_data.collection_name,
+                query=form_data.query,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
+                k=form_data.k if form_data.k else app.state.TOP_K,
+            )
     except Exception as e:
         log.exception(e)
         raise HTTPException(
@@ -410,27 +424,23 @@ def query_collection_handler(
     user=Depends(get_current_user),
 ):
     try:
-        embeddings_function = get_embeddings_function(
-            app.state.RAG_EMBEDDING_ENGINE,
-            app.state.RAG_EMBEDDING_MODEL,
-            app.state.sentence_transformer_ef,
-            app.state.OPENAI_API_KEY,
-            app.state.OPENAI_API_BASE_URL,
-        )
+        if app.state.ENABLE_RAG_HYBRID_SEARCH:
+            return query_collection_with_hybrid_search(
+                collection_names=form_data.collection_names,
+                query=form_data.query,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
+                reranking_function=app.state.sentence_transformer_rf,
+                k=form_data.k if form_data.k else app.state.TOP_K,
+                r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
+            )
+        else:
+            return query_collection(
+                collection_names=form_data.collection_names,
+                query=form_data.query,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
+                k=form_data.k if form_data.k else app.state.TOP_K,
+            )
 
-        return query_embeddings_collection(
-            collection_names=form_data.collection_names,
-            query=form_data.query,
-            k=form_data.k if form_data.k else app.state.TOP_K,
-            r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
-            embeddings_function=embeddings_function,
-            reranking_function=app.state.sentence_transformer_rf,
-            hybrid_search=(
-                form_data.hybrid
-                if form_data.hybrid
-                else app.state.ENABLE_RAG_HYBRID_SEARCH
-            ),
-        )
     except Exception as e:
         log.exception(e)
         raise HTTPException(
@@ -508,7 +518,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
 
         collection = CHROMA_CLIENT.create_collection(name=collection_name)
 
-        embedding_func = get_embeddings_function(
+        embedding_func = get_embedding_function(
             app.state.RAG_EMBEDDING_ENGINE,
             app.state.RAG_EMBEDDING_MODEL,
             app.state.sentence_transformer_ef,
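The body of `get_embedding_function` is outside this diff. Judging from its parameters, it plausibly returns a closure that embeds either locally via sentence-transformers or remotely via an OpenAI-compatible `/embeddings` endpoint; the sketch below is an assumption about that shape, not the project's actual implementation:

import requests

def get_embedding_function_sketch(
    embedding_engine, embedding_model, embedding_function, openai_key, openai_url
):
    # Assumption: only the OpenAI and local sentence-transformers
    # branches are sketched here.
    if embedding_engine == "openai":
        def embed(text: str):
            resp = requests.post(
                f"{openai_url}/embeddings",
                headers={"Authorization": f"Bearer {openai_key}"},
                json={"input": text, "model": embedding_model},
            )
            resp.raise_for_status()
            return resp.json()["data"][0]["embedding"]
        return embed
    # local path: SentenceTransformer.encode returns a numpy vector
    return lambda text: embedding_function.encode(text).tolist()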

backend/apps/rag/utils.py (+114, -95)

@@ -26,61 +26,72 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def query_embeddings_doc(
+def query_doc(
     collection_name: str,
     query: str,
-    embeddings_function,
-    reranking_function,
+    embedding_function,
     k: int,
-    r: int,
-    hybrid_search: bool,
 ):
     try:
         collection = CHROMA_CLIENT.get_collection(name=collection_name)
+        query_embeddings = embedding_function(query)
+        result = collection.query(
+            query_embeddings=[query_embeddings],
+            n_results=k,
+        )
 
-        if hybrid_search:
-            documents = collection.get()  # get all documents
-            bm25_retriever = BM25Retriever.from_texts(
-                texts=documents.get("documents"),
-                metadatas=documents.get("metadatas"),
-            )
-            bm25_retriever.k = k
+        log.info(f"query_doc:result {result}")
+        return result
+    except Exception as e:
+        raise e
 
-            chroma_retriever = ChromaRetriever(
-                collection=collection,
-                embeddings_function=embeddings_function,
-                top_n=k,
-            )
 
-            ensemble_retriever = EnsembleRetriever(
-                retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
-            )
+def query_doc_with_hybrid_search(
+    collection_name: str,
+    query: str,
+    embedding_function,
+    k: int,
+    reranking_function,
+    r: int,
+):
+    try:
+        collection = CHROMA_CLIENT.get_collection(name=collection_name)
+        documents = collection.get()  # get all documents
 
-            compressor = RerankCompressor(
-                embeddings_function=embeddings_function,
-                reranking_function=reranking_function,
-                r_score=r,
-                top_n=k,
-            )
+        bm25_retriever = BM25Retriever.from_texts(
+            texts=documents.get("documents"),
+            metadatas=documents.get("metadatas"),
+        )
+        bm25_retriever.k = k
 
-            compression_retriever = ContextualCompressionRetriever(
-                base_compressor=compressor, base_retriever=ensemble_retriever
-            )
+        chroma_retriever = ChromaRetriever(
+            collection=collection,
+            embedding_function=embedding_function,
+            top_n=k,
+        )
 
-            result = compression_retriever.invoke(query)
-            result = {
-                "distances": [[d.metadata.get("score") for d in result]],
-                "documents": [[d.page_content for d in result]],
-                "metadatas": [[d.metadata for d in result]],
-            }
-        else:
-            query_embeddings = embeddings_function(query)
-            result = collection.query(
-                query_embeddings=[query_embeddings],
-                n_results=k,
-            )
+        ensemble_retriever = EnsembleRetriever(
+            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
+        )
+
+        compressor = RerankCompressor(
+            embedding_function=embedding_function,
+            reranking_function=reranking_function,
+            r_score=r,
+            top_n=k,
+        )
+
+        compression_retriever = ContextualCompressionRetriever(
+            base_compressor=compressor, base_retriever=ensemble_retriever
+        )
 
-        log.info(f"query_embeddings_doc:result {result}")
+        result = compression_retriever.invoke(query)
+        result = {
+            "distances": [[d.metadata.get("score") for d in result]],
+            "documents": [[d.page_content for d in result]],
+            "metadatas": [[d.metadata for d in result]],
+        }
+        log.info(f"query_doc_with_hybrid_search:result {result}")
         return result
     except Exception as e:
         raise e
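Both query variants return a Chroma-style dict whose `distances`, `documents`, and `metadatas` values are lists of lists (one inner list per query), as the hybrid branch's result assembly above shows. A usage sketch; the collection name, embedder, and dimensionality are all hypothetical:

# Hypothetical caller; the lambda stands in for a real embedding model.
from apps.rag.utils import query_doc

result = query_doc(
    collection_name="my-docs",                 # hypothetical collection
    query="how does hybrid search work?",
    embedding_function=lambda q: [0.0] * 384,  # stand-in 384-dim embedder
    k=4,
)
for doc, meta in zip(result["documents"][0], result["metadatas"][0]):
    print(meta.get("source"), doc[:80])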
@@ -127,35 +138,52 @@ def merge_and_sort_query_results(query_results, k, reverse=False):
     return result
 
 
-def query_embeddings_collection(
+def query_collection(
     collection_names: List[str],
     query: str,
+    embedding_function,
+    k: int,
+):
+    results = []
+    for collection_name in collection_names:
+        try:
+            result = query_doc(
+                collection_name=collection_name,
+                query=query,
+                k=k,
+                embedding_function=embedding_function,
+            )
+            results.append(result)
+        except:
+            pass
+    return merge_and_sort_query_results(results, k=k)
+
+
+def query_collection_with_hybrid_search(
+    collection_names: List[str],
+    query: str,
+    embedding_function,
     k: int,
-    r: float,
-    embeddings_function,
     reranking_function,
-    hybrid_search: bool,
+    r: float,
 ):
 
     results = []
-
     for collection_name in collection_names:
         try:
-            result = query_embeddings_doc(
+            result = query_doc_with_hybrid_search(
                 collection_name=collection_name,
                 query=query,
+                embedding_function=embedding_function,
                 k=k,
-                r=r,
-                embeddings_function=embeddings_function,
                 reranking_function=reranking_function,
-                hybrid_search=hybrid_search,
+                r=r,
             )
             results.append(result)
         except:
             pass
 
-    reverse = hybrid_search and reranking_function is not None
-    return merge_and_sort_query_results(results, k=k, reverse=reverse)
+    return merge_and_sort_query_results(results, k=k, reverse=True)
 
 
 def rag_template(template: str, context: str, query: str):
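The old `reverse = hybrid_search and reranking_function is not None` logic is gone: the hybrid variant always sorts descending because reranker output is a similarity score (higher is better), while the plain variant keeps the default ascending order for Chroma distances (lower is better). The body of `merge_and_sort_query_results` is not in this diff; the sketch below is one plausible reading of its contract, inferred from the result shapes above:

# Assumed shape of merge_and_sort_query_results (real body not shown):
# flatten per-collection results, sort by score/distance, keep top k.
def merge_and_sort_sketch(query_results, k, reverse=False):
    rows = []
    for res in query_results:
        rows += zip(res["distances"][0], res["documents"][0], res["metadatas"][0])
    # reverse=True: hybrid rerank scores, higher is better.
    # reverse=False: plain Chroma distances, lower is better.
    rows.sort(key=lambda row: row[0], reverse=reverse)
    rows = rows[:k]
    return {
        "distances": [[r[0] for r in rows]],
        "documents": [[r[1] for r in rows]],
        "metadatas": [[r[2] for r in rows]],
    }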
@@ -164,7 +192,7 @@ def rag_template(template: str, context: str, query: str):
     return template
 
 
-def get_embeddings_function(
+def get_embedding_function(
     embedding_engine,
     embedding_model,
     embedding_function,
@@ -204,19 +232,13 @@ def rag_messages(
     docs,
     messages,
     template,
+    embedding_function,
     k,
+    reranking_function,
     r,
     hybrid_search,
-    embedding_engine,
-    embedding_model,
-    embedding_function,
-    reranking_function,
-    openai_key,
-    openai_url,
 ):
-    log.debug(
-        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
-    )
+    log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
 
     last_user_message_idx = None
     for i in range(len(messages) - 1, -1, -1):
@@ -243,14 +265,6 @@ def rag_messages(
         content_type = None
         query = ""
 
-    embeddings_function = get_embeddings_function(
-        embedding_engine,
-        embedding_model,
-        embedding_function,
-        openai_key,
-        openai_url,
-    )
-
     extracted_collections = []
     relevant_contexts = []
 
@@ -271,26 +285,31 @@ def rag_messages(
         try:
             if doc["type"] == "text":
                 context = doc["content"]
-            elif doc["type"] == "collection":
-                context = query_embeddings_collection(
-                    collection_names=doc["collection_names"],
-                    query=query,
-                    k=k,
-                    r=r,
-                    embeddings_function=embeddings_function,
-                    reranking_function=reranking_function,
-                    hybrid_search=hybrid_search,
-                )
             else:
-                context = query_embeddings_doc(
-                    collection_name=doc["collection_name"],
-                    query=query,
-                    k=k,
-                    r=r,
-                    embeddings_function=embeddings_function,
-                    reranking_function=reranking_function,
-                    hybrid_search=hybrid_search,
-                )
+                if hybrid_search:
+                    context = query_collection_with_hybrid_search(
+                        collection_names=(
+                            doc["collection_names"]
+                            if doc["type"] == "collection"
+                            else [doc["collection_name"]]
+                        ),
+                        query=query,
+                        embedding_function=embedding_function,
+                        k=k,
+                        reranking_function=reranking_function,
+                        r=r,
+                    )
+                else:
+                    context = query_collection(
+                        collection_names=(
+                            doc["collection_names"]
+                            if doc["type"] == "collection"
+                            else [doc["collection_name"]]
+                        ),
+                        query=query,
+                        embedding_function=embedding_function,
+                        k=k,
+                    )
         except Exception as e:
             log.exception(e)
             context = None
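Single documents now flow through the same collection path as multi-document collections, by wrapping the lone collection name in a list. A tiny illustration with hypothetical `doc` dicts:

# Hypothetical doc dicts, mirroring the two shapes handled above.
for doc in [
    {"type": "collection", "collection_names": ["a", "b"]},
    {"type": "doc", "collection_name": "c"},
]:
    collection_names = (
        doc["collection_names"]
        if doc["type"] == "collection"
        else [doc["collection_name"]]
    )
    print(collection_names)  # ['a', 'b'] then ['c']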
@@ -404,7 +423,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
 
 class ChromaRetriever(BaseRetriever):
     collection: Any
-    embeddings_function: Any
+    embedding_function: Any
     top_n: int
 
     def _get_relevant_documents(
@@ -413,7 +432,7 @@ class ChromaRetriever(BaseRetriever):
         *,
         run_manager: CallbackManagerForRetrieverRun,
     ) -> List[Document]:
-        query_embeddings = self.embeddings_function(query)
+        query_embeddings = self.embedding_function(query)
 
         results = self.collection.query(
             query_embeddings=[query_embeddings],
@@ -445,7 +464,7 @@ from sentence_transformers import util
 
 
 class RerankCompressor(BaseDocumentCompressor):
-    embeddings_function: Any
+    embedding_function: Any
     reranking_function: Any
     r_score: float
     top_n: int
@@ -465,8 +484,8 @@ class RerankCompressor(BaseDocumentCompressor):
                 [(query, doc.page_content) for doc in documents]
             )
         else:
-            query_embedding = self.embeddings_function(query)
-            document_embedding = self.embeddings_function(
+            query_embedding = self.embedding_function(query)
+            document_embedding = self.embedding_function(
                 [doc.page_content for doc in documents]
             )
             scores = util.cos_sim(query_embedding, document_embedding)[0]
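For context on the renamed fallback path: when no cross-encoder reranker is configured, `RerankCompressor` scores documents by cosine similarity between query and document embeddings. A standalone sketch of that scoring step with `sentence_transformers.util`; the model name and sentences are illustrative:

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model
query_embedding = model.encode("hybrid search")
document_embedding = model.encode(["bm25 plus dense retrieval", "keyword match only"])
scores = util.cos_sim(query_embedding, document_embedding)[0]
print(scores.tolist())  # one similarity per document, higher is better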

backend/main.py (+8, -12)

@@ -117,18 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware):
             if "docs" in data:
                 data = {**data}
                 data["messages"] = rag_messages(
-                    data["docs"],
-                    data["messages"],
-                    rag_app.state.RAG_TEMPLATE,
-                    rag_app.state.TOP_K,
-                    rag_app.state.RELEVANCE_THRESHOLD,
-                    rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
-                    rag_app.state.RAG_EMBEDDING_ENGINE,
-                    rag_app.state.RAG_EMBEDDING_MODEL,
-                    rag_app.state.sentence_transformer_ef,
-                    rag_app.state.sentence_transformer_rf,
-                    rag_app.state.OPENAI_API_KEY,
-                    rag_app.state.OPENAI_API_BASE_URL,
+                    docs=data["docs"],
+                    messages=data["messages"],
+                    template=rag_app.state.RAG_TEMPLATE,
+                    embedding_function=rag_app.state.EMBEDDING_FUNCTION,
+                    k=rag_app.state.TOP_K,
+                    reranking_function=rag_app.state.sentence_transformer_rf,
+                    r=rag_app.state.RELEVANCE_THRESHOLD,
+                    hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
                 )
                 del data["docs"]
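One design note on this call site: `rag_messages` reordered its parameters in this commit (`embedding_function` and `reranking_function` now sit between `template`, `k`, and `r`), so the switch from positional to keyword arguments here protects the middleware from silently misordered calls if the signature shifts again.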