|
@@ -271,14 +271,14 @@ def rag_messages(
|
|
|
for doc in docs:
|
|
|
context = None
|
|
|
|
|
|
- collection = doc.get("collection_name")
|
|
|
- if collection:
|
|
|
- collection = [collection]
|
|
|
- else:
|
|
|
- collection = doc.get("collection_names", [])
|
|
|
+ collection_names = (
|
|
|
+ doc["collection_names"]
|
|
|
+ if doc["type"] == "collection"
|
|
|
+ else [doc["collection_name"]]
|
|
|
+ )
|
|
|
|
|
|
- collection = set(collection).difference(extracted_collections)
|
|
|
- if not collection:
|
|
|
+ collection_names = set(collection_names).difference(extracted_collections)
|
|
|
+ if not collection_names:
|
|
|
log.debug(f"skipping {doc} as it has already been extracted")
|
|
|
continue
|
|
|
|
|
@@ -288,11 +288,7 @@ def rag_messages(
|
|
|
else:
|
|
|
if hybrid_search:
|
|
|
context = query_collection_with_hybrid_search(
|
|
|
- collection_names=(
|
|
|
- doc["collection_names"]
|
|
|
- if doc["type"] == "collection"
|
|
|
- else [doc["collection_name"]]
|
|
|
- ),
|
|
|
+ collection_names=collection_names,
|
|
|
query=query,
|
|
|
embedding_function=embedding_function,
|
|
|
k=k,
|
|
@@ -301,11 +297,7 @@ def rag_messages(
|
|
|
)
|
|
|
else:
|
|
|
context = query_collection(
|
|
|
- collection_names=(
|
|
|
- doc["collection_names"]
|
|
|
- if doc["type"] == "collection"
|
|
|
- else [doc["collection_name"]]
|
|
|
- ),
|
|
|
+ collection_names=collection_names,
|
|
|
query=query,
|
|
|
embedding_function=embedding_function,
|
|
|
k=k,
|
|
@@ -315,9 +307,9 @@ def rag_messages(
|
|
|
context = None
|
|
|
|
|
|
if context:
|
|
|
- relevant_contexts.append(context)
|
|
|
+ relevant_contexts.append({**context, "source": doc})
|
|
|
|
|
|
- extracted_collections.extend(collection)
|
|
|
+ extracted_collections.extend(collection_names)
|
|
|
|
|
|
context_string = ""
|
|
|
|
|
@@ -325,11 +317,14 @@ def rag_messages(
|
|
|
for context in relevant_contexts:
|
|
|
try:
|
|
|
if "documents" in context:
|
|
|
- items = [item for item in context["documents"][0] if item is not None]
|
|
|
- context_string += "\n\n".join(items)
|
|
|
+ context_string += "\n\n".join(
|
|
|
+ [text for text in context["documents"][0] if text is not None]
|
|
|
+ )
|
|
|
+
|
|
|
if "metadatas" in context:
|
|
|
citations.append(
|
|
|
{
|
|
|
+ "source": context["source"],
|
|
|
"document": context["documents"][0],
|
|
|
"metadata": context["metadatas"][0],
|
|
|
}
|