Browse Source

feat: show RAG query results as citations

Jun Siang Cheah 1 year ago
parent
commit
0872bea790

+ 9 - 1
backend/apps/rag/utils.py

@@ -320,11 +320,19 @@ def rag_messages(
         extracted_collections.extend(collection)
 
     context_string = ""
+    citations = []
     for context in relevant_contexts:
         try:
             if "documents" in context:
                 items = [item for item in context["documents"][0] if item is not None]
                 context_string += "\n\n".join(items)
+                if "metadatas" in context:
+                    citations.append(
+                        {
+                            "document": context["documents"][0],
+                            "metadata": context["metadatas"][0],
+                        }
+                    )
         except Exception as e:
             log.exception(e)
     context_string = context_string.strip()
@@ -355,7 +363,7 @@ def rag_messages(
 
     messages[last_user_message_idx] = new_user_message
 
-    return messages
+    return messages, citations
 
 
 def get_model_path(model: str, update_model: bool = False):

+ 36 - 3
backend/main.py

@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
-
+from starlette.responses import StreamingResponse
 
 from apps.ollama.main import app as ollama_app
 from apps.openai.main import app as openai_app
@@ -102,6 +102,8 @@ origins = ["*"]
 
 class RAGMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
+        return_citations = False
+
         if request.method == "POST" and (
             "/api/chat" in request.url.path or "/chat/completions" in request.url.path
         ):
@@ -114,11 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
             # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
 
+            return_citations = data.get("citations", False)
+            if "citations" in data:
+                del data["citations"]
+
             # Example: Add a new key-value pair or modify existing ones
             # data["modified"] = True  # Example modification
             if "docs" in data:
                 data = {**data}
-                data["messages"] = rag_messages(
+                data["messages"], citations = rag_messages(
                     docs=data["docs"],
                     messages=data["messages"],
                     template=rag_app.state.RAG_TEMPLATE,
@@ -130,7 +136,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
                 )
                 del data["docs"]
 
-                log.debug(f"data['messages']: {data['messages']}")
+                log.debug(
+                    f"data['messages']: {data['messages']}, citations: {citations}"
+                )
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
 
@@ -148,11 +156,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
             ]
 
         response = await call_next(request)
+
+        if return_citations:
+            # Inject the citations into the response
+            if isinstance(response, StreamingResponse):
+                # If it's a streaming response, inject it as SSE event or NDJSON line
+                content_type = response.headers.get("Content-Type")
+                if "text/event-stream" in content_type:
+                    return StreamingResponse(
+                        self.openai_stream_wrapper(response.body_iterator, citations),
+                    )
+                if "application/x-ndjson" in content_type:
+                    return StreamingResponse(
+                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                    )
+
         return response
 
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
 
+    async def openai_stream_wrapper(self, original_generator, citations):
+        yield f"data: {json.dumps({'citations': citations})}\n\n"
+        async for data in original_generator:
+            yield data
+
+    async def ollama_stream_wrapper(self, original_generator, citations):
+        yield f"{json.dumps({'citations': citations})}\n"
+        async for data in original_generator:
+            yield data
+
 
 app.add_middleware(RAGMiddleware)
 

+ 11 - 0
src/lib/apis/streaming/index.ts

@@ -4,6 +4,8 @@ import type { ParsedEvent } from 'eventsource-parser';
 type TextStreamUpdate = {
 	done: boolean;
 	value: string;
+	// eslint-disable-next-line @typescript-eslint/no-explicit-any
+	citations?: any;
 };
 
 // createOpenAITextStream takes a responseBody with a SSE response,
@@ -45,6 +47,11 @@ async function* openAIStreamToIterator(
 			const parsedData = JSON.parse(data);
 			console.log(parsedData);
 
+			if (parsedData.citations) {
+				yield { done: false, value: '', citations: parsedData.citations };
+				continue;
+			}
+
 			yield { done: false, value: parsedData.choices?.[0]?.delta?.content ?? '' };
 		} catch (e) {
 			console.error('Error extracting delta from SSE event:', e);
@@ -62,6 +69,10 @@ async function* streamLargeDeltasAsRandomChunks(
 			yield textStreamUpdate;
 			return;
 		}
+		if (textStreamUpdate.citations) {
+			yield textStreamUpdate;
+			continue;
+		}
 		let content = textStreamUpdate.value;
 		if (content.length < 5) {
 			yield { done: false, value: content };

+ 75 - 0
src/lib/components/chat/Messages/CitationsModal.svelte

@@ -0,0 +1,75 @@
+<script lang="ts">
+	import { getContext, onMount, tick } from 'svelte';
+
+	import Modal from '$lib/components/common/Modal.svelte';
+	const i18n = getContext('i18n');
+
+	export let show = false;
+	export let citation: any[];
+
+	let mergedDocuments = [];
+
+	onMount(async () => {
+		console.log(citation);
+		// Merge the document with its metadata
+		mergedDocuments = citation.document?.map((c, i) => {
+			return {
+				document: c,
+				metadata: citation.metadata?.[i]
+			};
+		});
+		console.log(mergedDocuments);
+	});
+</script>
+
+<Modal size="lg" bind:show>
+	<div>
+		<div class=" flex justify-between dark:text-gray-300 px-5 py-4">
+			<div class=" text-lg font-medium self-center capitalize">
+				{$i18n.t('Citation')}
+			</div>
+			<button
+				class="self-center"
+				on:click={() => {
+					show = false;
+				}}
+			>
+				<svg
+					xmlns="http://www.w3.org/2000/svg"
+					viewBox="0 0 20 20"
+					fill="currentColor"
+					class="w-5 h-5"
+				>
+					<path
+						d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
+					/>
+				</svg>
+			</button>
+		</div>
+		<hr class=" dark:border-gray-850" />
+
+		<div class="flex flex-col w-full px-5 py-4 dark:text-gray-200 overflow-y-scroll max-h-[22rem]">
+			{#each mergedDocuments as document}
+				<!-- Source from document.metadata.source -->
+				<div class="flex flex-col w-full">
+					<div class="text-lg font-medium dark:text-gray-300">
+						{$i18n.t('Source')}
+					</div>
+					<div class="text-sm dark:text-gray-400">
+						{document.metadata.source}
+					</div>
+				</div>
+				<!-- Content from document.document.content -->
+				<div class="flex flex-col w-full">
+					<div class="text-lg font-medium dark:text-gray-300">
+						{$i18n.t('Content')}
+					</div>
+					<pre class="text-sm dark:text-gray-400">
+						{document.document}
+					</pre>
+				</div>
+				<hr class=" dark:border-gray-850" />
+			{/each}
+		</div>
+	</div>
+</Modal>

+ 73 - 32
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -32,6 +32,7 @@
 	import { WEBUI_BASE_URL } from '$lib/constants';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
 	import RateComment from './RateComment.svelte';
+	import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
 
 	export let modelfiles = [];
 	export let message;
@@ -65,6 +66,8 @@
 
 	let showRateComment = false;
 
+	let showCitations = {};
+
 	$: tokens = marked.lexer(sanitizeResponseContent(message.content));
 
 	const renderer = new marked.Renderer();
@@ -360,6 +363,48 @@
 						{/each}
 					</div>
 				{/if}
+				{#if message.citations}
+					<div class="my-2.5 w-full flex overflow-x-auto gap-2 flex-wrap">
+						{#each message.citations as citation}
+							<div>
+								<CitationsModal bind:show={showCitations[citation]} {citation} />
+								<button
+									class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 dark:bg-gray-600 rounded-xl border border-gray-200 dark:border-none text-left"
+									type="button"
+									on:click={() => {
+										showCitations[citation] = !showCitations[citation];
+									}}
+								>
+									<div class="p-2.5 bg-red-400 text-white rounded-lg">
+										<svg
+											xmlns="http://www.w3.org/2000/svg"
+											viewBox="0 0 24 24"
+											fill="currentColor"
+											class="w-6 h-6"
+										>
+											<path
+												fill-rule="evenodd"
+												d="M5.625 1.5c-1.036 0-1.875.84-1.875 1.875v17.25c0 1.035.84 1.875 1.875 1.875h12.75c1.035 0 1.875-.84 1.875-1.875V12.75A3.75 3.75 0 0 0 16.5 9h-1.875a1.875 1.875 0 0 1-1.875-1.875V5.25A3.75 3.75 0 0 0 9 1.5H5.625ZM7.5 15a.75.75 0 0 1 .75-.75h7.5a.75.75 0 0 1 0 1.5h-7.5A.75.75 0 0 1 7.5 15Zm.75 2.25a.75.75 0 0 0 0 1.5H12a.75.75 0 0 0 0-1.5H8.25Z"
+												clip-rule="evenodd"
+											/>
+											<path
+												d="M12.971 1.816A5.23 5.23 0 0 1 14.25 5.25v1.875c0 .207.168.375.375.375H16.5a5.23 5.23 0 0 1 3.434 1.279 9.768 9.768 0 0 0-6.963-6.963Z"
+											/>
+										</svg>
+									</div>
+
+									<div class="flex flex-col justify-center -space-y-0.5">
+										<div class=" dark:text-gray-100 text-sm font-medium line-clamp-1">
+											{citation.metadata?.[0]?.source ?? 'N/A'}
+										</div>
+
+										<div class=" text-gray-500 text-sm">{$i18n.t('Document')}</div>
+									</div>
+								</button>
+							</div>
+						{/each}
+					</div>
+				{/if}
 
 				<div
 					class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:m-0 prose-p:-mb-6 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-8 prose-ol:p-0 prose-li:-mb-4 whitespace-pre-line"
@@ -577,10 +622,11 @@
 														stroke-linejoin="round"
 														class="w-4 h-4"
 														xmlns="http://www.w3.org/2000/svg"
-														><path
-															d="M14 9V5a3 3 0 0 0-3-3l-4 9v11h11.28a2 2 0 0 0 2-1.7l1.38-9a2 2 0 0 0-2-2.3zM7 22H4a2 2 0 0 1-2-2v-7a2 2 0 0 1 2-2h3"
-														/></svg
 													>
+														<path
+															d="M14 9V5a3 3 0 0 0-3-3l-4 9v11h11.28a2 2 0 0 0 2-1.7l1.38-9a2 2 0 0 0-2-2.3zM7 22H4a2 2 0 0 1-2-2v-7a2 2 0 0 1 2-2h3"
+														/>
+													</svg>
 												</button>
 											</Tooltip>
 
@@ -611,10 +657,11 @@
 														stroke-linejoin="round"
 														class="w-4 h-4"
 														xmlns="http://www.w3.org/2000/svg"
-														><path
-															d="M10 15v4a3 3 0 0 0 3 3l4-9V2H5.72a2 2 0 0 0-2 1.7l-1.38 9a2 2 0 0 0 2 2.3zm7-13h2.67A2.31 2.31 0 0 1 22 4v7a2.31 2.31 0 0 1-2.33 2H17"
-														/></svg
 													>
+														<path
+															d="M10 15v4a3 3 0 0 0 3 3l4-9V2H5.72a2 2 0 0 0-2 1.7l-1.38 9a2 2 0 0 0 2 2.3zm7-13h2.67A2.31 2.31 0 0 1 22 4v7a2.31 2.31 0 0 1-2.33 2H17"
+														/>
+													</svg>
 												</button>
 											</Tooltip>
 										{/if}
@@ -637,35 +684,32 @@
 														fill="currentColor"
 														viewBox="0 0 24 24"
 														xmlns="http://www.w3.org/2000/svg"
-														><style>
+													>
+														<style>
 															.spinner_S1WN {
 																animation: spinner_MGfb 0.8s linear infinite;
 																animation-delay: -0.8s;
 															}
+
 															.spinner_Km9P {
 																animation-delay: -0.65s;
 															}
+
 															.spinner_JApP {
 																animation-delay: -0.5s;
 															}
+
 															@keyframes spinner_MGfb {
 																93.75%,
 																100% {
 																	opacity: 0.2;
 																}
 															}
-														</style><circle class="spinner_S1WN" cx="4" cy="12" r="3" /><circle
-															class="spinner_S1WN spinner_Km9P"
-															cx="12"
-															cy="12"
-															r="3"
-														/><circle
-															class="spinner_S1WN spinner_JApP"
-															cx="20"
-															cy="12"
-															r="3"
-														/></svg
-													>
+														</style>
+														<circle class="spinner_S1WN" cx="4" cy="12" r="3" />
+														<circle class="spinner_S1WN spinner_Km9P" cx="12" cy="12" r="3" />
+														<circle class="spinner_S1WN spinner_JApP" cx="20" cy="12" r="3" />
+													</svg>
 												{:else if speaking}
 													<svg
 														xmlns="http://www.w3.org/2000/svg"
@@ -718,35 +762,32 @@
 															fill="currentColor"
 															viewBox="0 0 24 24"
 															xmlns="http://www.w3.org/2000/svg"
-															><style>
+														>
+															<style>
 																.spinner_S1WN {
 																	animation: spinner_MGfb 0.8s linear infinite;
 																	animation-delay: -0.8s;
 																}
+
 																.spinner_Km9P {
 																	animation-delay: -0.65s;
 																}
+
 																.spinner_JApP {
 																	animation-delay: -0.5s;
 																}
+
 																@keyframes spinner_MGfb {
 																	93.75%,
 																	100% {
 																		opacity: 0.2;
 																	}
 																}
-															</style><circle class="spinner_S1WN" cx="4" cy="12" r="3" /><circle
-																class="spinner_S1WN spinner_Km9P"
-																cx="12"
-																cy="12"
-																r="3"
-															/><circle
-																class="spinner_S1WN spinner_JApP"
-																cx="20"
-																cy="12"
-																r="3"
-															/></svg
-														>
+															</style>
+															<circle class="spinner_S1WN" cx="4" cy="12" r="3" />
+															<circle class="spinner_S1WN spinner_Km9P" cx="12" cy="12" r="3" />
+															<circle class="spinner_S1WN spinner_JApP" cx="20" cy="12" r="3" />
+														</svg>
 													{:else}
 														<svg
 															xmlns="http://www.w3.org/2000/svg"

+ 15 - 3
src/routes/(app)/+page.svelte

@@ -366,7 +366,8 @@
 			},
 			format: $settings.requestFormat ?? undefined,
 			keep_alive: $settings.keepAlive ?? undefined,
-			docs: docs.length > 0 ? docs : undefined
+			docs: docs.length > 0 ? docs : undefined,
+			citations: docs.length > 0
 		});
 
 		if (res && res.ok) {
@@ -401,6 +402,11 @@
 							console.log(line);
 							let data = JSON.parse(line);
 
+							if ('citations' in data) {
+								responseMessage.citations = data.citations;
+								continue;
+							}
+
 							if ('detail' in data) {
 								throw data;
 							}
@@ -598,7 +604,8 @@
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
 				max_tokens: $settings?.options?.num_predict ?? undefined,
-				docs: docs.length > 0 ? docs : undefined
+				docs: docs.length > 0 ? docs : undefined,
+				citations: docs.length > 0
 			},
 			model?.source?.toLowerCase() === 'litellm'
 				? `${LITELLM_API_BASE_URL}/v1`
@@ -614,7 +621,7 @@
 			const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
 
 			for await (const update of textStream) {
-				const { value, done } = update;
+				const { value, done, citations } = update;
 				if (done || stopResponseFlag || _chatId !== $chatId) {
 					responseMessage.done = true;
 					messages = messages;
@@ -626,6 +633,11 @@
 					break;
 				}
 
+				if (citations) {
+					responseMessage.citations = citations;
+					continue;
+				}
+
 				if (responseMessage.content == '' && value == '\n') {
 					continue;
 				} else {

+ 15 - 3
src/routes/(app)/c/[id]/+page.svelte

@@ -378,7 +378,8 @@
 			},
 			format: $settings.requestFormat ?? undefined,
 			keep_alive: $settings.keepAlive ?? undefined,
-			docs: docs.length > 0 ? docs : undefined
+			docs: docs.length > 0 ? docs : undefined,
+			citations: docs.length > 0
 		});
 
 		if (res && res.ok) {
@@ -413,6 +414,11 @@
 							console.log(line);
 							let data = JSON.parse(line);
 
+							if ('citations' in data) {
+								responseMessage.citations = data.citations;
+								continue;
+							}
+
 							if ('detail' in data) {
 								throw data;
 							}
@@ -610,7 +616,8 @@
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
 				max_tokens: $settings?.options?.num_predict ?? undefined,
-				docs: docs.length > 0 ? docs : undefined
+				docs: docs.length > 0 ? docs : undefined,
+				citations: docs.length > 0
 			},
 			model?.source?.toLowerCase() === 'litellm'
 				? `${LITELLM_API_BASE_URL}/v1`
@@ -626,7 +633,7 @@
 			const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
 
 			for await (const update of textStream) {
-				const { value, done } = update;
+				const { value, done, citations } = update;
 				if (done || stopResponseFlag || _chatId !== $chatId) {
 					responseMessage.done = true;
 					messages = messages;
@@ -638,6 +645,11 @@
 					break;
 				}
 
+				if (citations) {
+					responseMessage.citations = citations;
+					continue;
+				}
+
 				if (responseMessage.content == '' && value == '\n') {
 					continue;
 				} else {