Browse Source

refac: include source name in citation

Timothy J. Baek 1 year ago
parent
commit
64ed0d1089

+ 16 - 21
backend/apps/rag/utils.py

@@ -271,14 +271,14 @@ def rag_messages(
     for doc in docs:
         context = None
 
-        collection = doc.get("collection_name")
-        if collection:
-            collection = [collection]
-        else:
-            collection = doc.get("collection_names", [])
+        collection_names = (
+            doc["collection_names"]
+            if doc["type"] == "collection"
+            else [doc["collection_name"]]
+        )
 
-        collection = set(collection).difference(extracted_collections)
-        if not collection:
+        collection_names = set(collection_names).difference(extracted_collections)
+        if not collection_names:
             log.debug(f"skipping {doc} as it has already been extracted")
             continue
 
@@ -288,11 +288,7 @@ def rag_messages(
             else:
                 if hybrid_search:
                     context = query_collection_with_hybrid_search(
-                        collection_names=(
-                            doc["collection_names"]
-                            if doc["type"] == "collection"
-                            else [doc["collection_name"]]
-                        ),
+                        collection_names=collection_names,
                         query=query,
                         embedding_function=embedding_function,
                         k=k,
@@ -301,11 +297,7 @@ def rag_messages(
                     )
                 else:
                     context = query_collection(
-                        collection_names=(
-                            doc["collection_names"]
-                            if doc["type"] == "collection"
-                            else [doc["collection_name"]]
-                        ),
+                        collection_names=collection_names,
                         query=query,
                         embedding_function=embedding_function,
                         k=k,
@@ -315,9 +307,9 @@ def rag_messages(
             context = None
 
         if context:
-            relevant_contexts.append(context)
+            relevant_contexts.append({**context, "source": doc})
 
-        extracted_collections.extend(collection)
+        extracted_collections.extend(collection_names)
 
     context_string = ""
 
@@ -325,11 +317,14 @@ def rag_messages(
     for context in relevant_contexts:
         try:
             if "documents" in context:
-                items = [item for item in context["documents"][0] if item is not None]
-                context_string += "\n\n".join(items)
+                context_string += "\n\n".join(
+                    [text for text in context["documents"][0] if text is not None]
+                )
+
                 if "metadatas" in context:
                     citations.append(
                         {
+                            "source": context["source"],
                             "document": context["documents"][0],
                             "metadata": context["metadatas"][0],
                         }

+ 2 - 2
src/lib/components/chat/Messages/CitationsModal.svelte

@@ -10,10 +10,10 @@
 	let mergedDocuments = [];
 
 	onMount(async () => {
-		console.log(citation);
 		// Merge the document with its metadata
 		mergedDocuments = citation.document?.map((c, i) => {
 			return {
+				source: citation.source,
 				document: c,
 				metadata: citation.metadata?.[i]
 			};
@@ -54,7 +54,7 @@
 							{$i18n.t('Source')}
 						</div>
 						<div class="text-sm dark:text-gray-400">
-							{document.metadata?.source ?? $i18n.t('No source available')}
+							{document.source?.name ?? $i18n.t('No source available')}
 						</div>
 					</div>
 					<div class="flex flex-col w-full">

+ 20 - 30
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -66,9 +66,8 @@
 
 	let showRateComment = false;
 
-	let showCitations = {};
 	// Backend returns a list of citations per collection, we flatten it to citations per source
-	let flattenedCitations = {};
+	let citations = {};
 
 	$: tokens = marked.lexer(sanitizeResponseContent(message.content));
 
@@ -137,27 +136,21 @@
 		}
 
 		if (message.citations) {
-			for (const citation of message.citations) {
-				const zipped = (citation?.document ?? []).map(function (document, index) {
-					return [document, citation.metadata?.[index]];
+			message.citations.forEach((citation) => {
+				citation.document.forEach((document, index) => {
+					const metadata = citation.metadata?.[index];
+					const source = citation?.source?.name ?? metadata?.source ?? 'N/A';
+
+					citations[source] = citations[source] || {
+						source: citation.source,
+						document: [],
+						metadata: []
+					};
+
+					citations[source].document.push(document);
+					citations[source].metadata.push(metadata);
 				});
-
-				for (const [document, metadata] of zipped) {
-					const source = metadata?.source ?? 'N/A';
-					if (source in flattenedCitations) {
-						flattenedCitations[source].document.push(document);
-						flattenedCitations[source].metadata.push(metadata);
-					} else {
-						flattenedCitations[source] = {
-							document: [document],
-							metadata: [metadata]
-						};
-					}
-				}
-			}
-
-			console.log(flattenedCitations);
-			console.log(Object.keys(flattenedCitations));
+			});
 		}
 	};
 
@@ -474,15 +467,12 @@
 					</div>
 				</div>
 
-				{#if Object.keys(flattenedCitations).length > 0}
+				{#if Object.keys(citations).length > 0}
 					<hr class="  dark:border-gray-800" />
 
 					<div class="my-2.5 w-full flex overflow-x-auto gap-2 flex-wrap">
-						{#each [...Object.keys(flattenedCitations)] as source, idx}
-							<CitationsModal
-								bind:show={showCitations[source]}
-								citation={flattenedCitations[source]}
-							/>
+						{#each Object.keys(citations) as source, idx}
+							<CitationsModal bind:show={citations[source].show} citation={citations[source]} />
 
 							<div class="flex gap-1 text-xs font-semibold">
 								<div>
@@ -492,10 +482,10 @@
 								<button
 									class="dark:text-gray-500 underline"
 									on:click={() => {
-										showCitations[source] = !showCitations[source];
+										citations[source].show = !citations[source].show;
 									}}
 								>
-									{source}
+									{citations[source].source.name}
 								</button>
 							</div>
 						{/each}