Bläddra i källkod

fix: sort ranking hybrid

Steven Kreitzer 1 år sedan
förälder
incheckning
69822e4c25
2 ändrade filer med 13 tillägg och 17 borttagningar
  1. 12 17
      backend/apps/rag/utils.py
  2. 1 0
      backend/main.py

+ 12 - 17
backend/apps/rag/utils.py

@@ -18,8 +18,6 @@ from langchain.retrievers import (
     EnsembleRetriever,
 )
 
-from sentence_transformers import CrossEncoder
-
 from typing import Optional
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
@@ -34,14 +32,13 @@ def query_embeddings_doc(
     embeddings_function,
     reranking_function,
     k: int,
-    r: Optional[float] = None,
-    hybrid: Optional[bool] = False,
+    r: int,
+    hybrid: bool,
 ):
     try:
-        if hybrid:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(name=collection_name)
+        collection = CHROMA_CLIENT.get_collection(name=collection_name)
 
+        if hybrid:
             documents = collection.get()  # get all documents
             bm25_retriever = BM25Retriever.from_texts(
                 texts=documents.get("documents"),
@@ -77,24 +74,19 @@ def query_embeddings_doc(
                 "metadatas": [[d.metadata for d in result]],
             }
         else:
-            # if you use docker use the model from the environment variable
             query_embeddings = embeddings_function(query)
-
-            log.info(f"query_embeddings_doc {query_embeddings}")
-            collection = CHROMA_CLIENT.get_collection(name=collection_name)
-
             result = collection.query(
                 query_embeddings=[query_embeddings],
                 n_results=k,
             )
 
-            log.info(f"query_embeddings_doc:result {result}")
+        log.info(f"query_embeddings_doc:result {result}")
         return result
     except Exception as e:
         raise e
 
 
-def merge_and_sort_query_results(query_results, k):
+def merge_and_sort_query_results(query_results, k, reverse=False):
     # Initialize lists to store combined data
     combined_distances = []
     combined_documents = []
@@ -109,7 +101,7 @@ def merge_and_sort_query_results(query_results, k):
     combined = list(zip(combined_distances, combined_documents, combined_metadatas))
 
     # Sort the list based on distances
-    combined.sort(key=lambda x: x[0])
+    combined.sort(key=lambda x: x[0], reverse=reverse)
 
     # We don't have anything :-(
     if not combined:
@@ -162,7 +154,8 @@ def query_embeddings_collection(
         except:
             pass
 
-    return merge_and_sort_query_results(results, k)
+    reverse = hybrid and reranking_function is not None
+    return merge_and_sort_query_results(results, k=k, reverse=reverse)
 
 
 def rag_template(template: str, context: str, query: str):
@@ -484,7 +477,9 @@ class RerankCompressor(BaseDocumentCompressor):
                 (d, s) for d, s in docs_with_scores if s >= self.r_score
             ]
 
-        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
+        reverse = self.reranking_function is not None
+        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse)
+
         final_results = []
         for doc, doc_score in result[: self.top_n]:
             metadata = doc.metadata

+ 1 - 0
backend/main.py

@@ -121,6 +121,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     rag_app.state.RAG_TEMPLATE,
                     rag_app.state.TOP_K,
                     rag_app.state.RELEVANCE_THRESHOLD,
+                    rag_app.state.HYBRID,
                     rag_app.state.RAG_EMBEDDING_ENGINE,
                     rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.sentence_transformer_ef,