|
@@ -35,6 +35,7 @@ def query_doc(
|
|
|
try:
|
|
|
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
|
|
query_embeddings = embedding_function(query)
|
|
|
+
|
|
|
result = collection.query(
|
|
|
query_embeddings=[query_embeddings],
|
|
|
n_results=k,
|
|
@@ -76,9 +77,9 @@ def query_doc_with_hybrid_search(
|
|
|
|
|
|
compressor = RerankCompressor(
|
|
|
embedding_function=embedding_function,
|
|
|
+ top_n=k,
|
|
|
reranking_function=reranking_function,
|
|
|
r_score=r,
|
|
|
- top_n=k,
|
|
|
)
|
|
|
|
|
|
compression_retriever = ContextualCompressionRetriever(
|
|
@@ -91,6 +92,7 @@ def query_doc_with_hybrid_search(
|
|
|
"documents": [[d.page_content for d in result]],
|
|
|
"metadatas": [[d.metadata for d in result]],
|
|
|
}
|
|
|
+
|
|
|
log.info(f"query_doc_with_hybrid_search:result {result}")
|
|
|
return result
|
|
|
except Exception as e:
|
|
@@ -167,7 +169,6 @@ def query_collection_with_hybrid_search(
|
|
|
reranking_function,
|
|
|
r: float,
|
|
|
):
|
|
|
-
|
|
|
results = []
|
|
|
for collection_name in collection_names:
|
|
|
try:
|
|
@@ -182,7 +183,6 @@ def query_collection_with_hybrid_search(
|
|
|
results.append(result)
|
|
|
except:
|
|
|
pass
|
|
|
-
|
|
|
return merge_and_sort_query_results(results, k=k, reverse=True)
|
|
|
|
|
|
|
|
@@ -443,13 +443,15 @@ class ChromaRetriever(BaseRetriever):
|
|
|
metadatas = results["metadatas"][0]
|
|
|
documents = results["documents"][0]
|
|
|
|
|
|
- return [
|
|
|
- Document(
|
|
|
- metadata=metadatas[idx],
|
|
|
- page_content=documents[idx],
|
|
|
+ results = []
|
|
|
+ for idx in range(len(ids)):
|
|
|
+ results.append(
|
|
|
+ Document(
|
|
|
+ metadata=metadatas[idx],
|
|
|
+ page_content=documents[idx],
|
|
|
+ )
|
|
|
)
|
|
|
- for idx in range(len(ids))
|
|
|
- ]
|
|
|
+ return results
|
|
|
|
|
|
|
|
|
import operator
|
|
@@ -465,9 +467,9 @@ from sentence_transformers import util
|
|
|
|
|
|
class RerankCompressor(BaseDocumentCompressor):
|
|
|
embedding_function: Any
|
|
|
+ top_n: int
|
|
|
reranking_function: Any
|
|
|
r_score: float
|
|
|
- top_n: int
|
|
|
|
|
|
class Config:
|
|
|
extra = Extra.forbid
|
|
@@ -479,7 +481,9 @@ class RerankCompressor(BaseDocumentCompressor):
|
|
|
query: str,
|
|
|
callbacks: Optional[Callbacks] = None,
|
|
|
) -> Sequence[Document]:
|
|
|
- if self.reranking_function:
|
|
|
+ reranking = self.reranking_function is not None
|
|
|
+
|
|
|
+ if reranking:
|
|
|
scores = self.reranking_function.predict(
|
|
|
[(query, doc.page_content) for doc in documents]
|
|
|
)
|
|
@@ -496,9 +500,7 @@ class RerankCompressor(BaseDocumentCompressor):
|
|
|
(d, s) for d, s in docs_with_scores if s >= self.r_score
|
|
|
]
|
|
|
|
|
|
- reverse = self.reranking_function is not None
|
|
|
- result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse)
|
|
|
-
|
|
|
+ result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
|
|
final_results = []
|
|
|
for doc, doc_score in result[: self.top_n]:
|
|
|
metadata = doc.metadata
|