Browse Source

enh: inline citations

Timothy Jaeryang Baek 5 months ago
parent
commit
386c976e9a

+ 19 - 10
backend/open_webui/config.py

@@ -1181,21 +1181,30 @@ CHUNK_OVERLAP = PersistentConfig(
     int(os.environ.get("CHUNK_OVERLAP", "100")),
 )
 
-DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
+DEFAULT_RAG_TEMPLATE = """### Task:
+Respond to the user query using the provided context, incorporating inline citations in the format [source_id].
+
+### Guidelines:
+- If you don't know the answer, clearly state that.
+- If uncertain, ask the user for clarification.
+- Respond in the same language as the user's query.
+- If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
+- If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding.
+- Include inline citations using [source_id] corresponding to the sources listed in the context.
+- Do not use XML tags in your response.
+- Ensure citations are concise and directly related to the information provided.
+
+### Example of Citation:
+If the user asks about a specific topic and the information is found in "whitepaper.pdf", the response should include the citation like so:  
+* "According to the study, the proposed method increases efficiency by 20% [whitepaper.pdf]."
+
+### Output:
+Provide a clear and direct response to the user's query, including inline citations in the format [source_id] where relevant.
 
 <context>
 {{CONTEXT}}
 </context>
 
-<rules>
-- If you don't know, just say so.
-- If you are not sure, ask for clarification.
-- Answer in the same language as the user query.
-- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
-- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
-- Answer directly and without using xml tags.
-</rules>
-
 <user_query>
 {{QUERY}}
 </user_query>

+ 7 - 1
backend/open_webui/main.py

@@ -679,7 +679,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
         # If context is not empty, insert it into the messages
         if len(contexts) > 0:
-            context_string = "/n".join(contexts).strip()
+            context_string = ""
+            for context_idx, context in enumerate(contexts):
+                print(context)
+                source_id = citations[context_idx].get("source", {}).get("name", "")
+                context_string += f"<source><source_id>{source_id}</source_id><source_context>{context}</source_context></source>\n"
+
+            context_string = context_string.strip()
             prompt = get_last_user_message(body["messages"])
 
             if prompt is None:

+ 1 - 0
src/lib/components/chat/Messages/Citations.svelte

@@ -94,6 +94,7 @@
 			<div class="flex text-xs font-medium">
 				{#each _citations as citation, idx}
 					<button
+						id={`source-${citation.source.name}`}
 						class="no-toggle outline-none flex dark:text-gray-300 p-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-900 dark:hover:bg-gray-850 transition rounded-xl max-w-96"
 						on:click={() => {
 							showCitationModal = true;

+ 4 - 0
src/lib/components/chat/Messages/ContentRenderer.svelte

@@ -11,9 +11,11 @@
 	export let id;
 	export let content;
 	export let model = null;
+	export let citations = null;
 
 	export let save = false;
 	export let floatingButtons = true;
+	export let onSourceClick = () => {};
 
 	let contentContainerElement;
 	let buttonsContainerElement;
@@ -129,6 +131,8 @@
 		{content}
 		{model}
 		{save}
+		sourceIds={(citations ?? []).map((c) => c?.source?.name)}
+		{onSourceClick}
 		on:update={(e) => {
 			dispatch('update', e.detail);
 		}}

+ 5 - 1
src/lib/components/chat/Messages/Markdown.svelte

@@ -16,6 +16,9 @@
 	export let model = null;
 	export let save = false;
 
+	export let sourceIds = [];
+	export let onSourceClick = () => {};
+
 	let tokens = [];
 
 	const options = {
@@ -28,7 +31,7 @@
 	$: (async () => {
 		if (content) {
 			tokens = marked.lexer(
-				replaceTokens(processResponseContent(content), model?.name, $user?.name)
+				replaceTokens(processResponseContent(content), sourceIds, model?.name, $user?.name)
 			);
 		}
 	})();
@@ -39,6 +42,7 @@
 		{tokens}
 		{id}
 		{save}
+		{onSourceClick}
 		on:update={(e) => {
 			dispatch('update', e.detail);
 		}}

+ 4 - 0
src/lib/components/chat/Messages/Markdown/MarkdownInlineTokens.svelte

@@ -12,9 +12,11 @@
 
 	import Image from '$lib/components/common/Image.svelte';
 	import KatexRenderer from './KatexRenderer.svelte';
+	import Source from './Source.svelte';
 
 	export let id: string;
 	export let tokens: Token[];
+	export let onSourceClick: Function = () => {};
 </script>
 
 {#each tokens as token}
@@ -26,6 +28,8 @@
 			{@html html}
 		{:else if token.text.includes(`<iframe src="${WEBUI_BASE_URL}/api/v1/files/`)}
 			{@html `${token.text}`}
+		{:else if token.text.includes(`<source_id`)}
+			<Source {token} onClick={onSourceClick} />
 		{:else}
 			{token.text}
 		{/if}

+ 15 - 4
src/lib/components/chat/Messages/Markdown/MarkdownTokens.svelte

@@ -25,6 +25,7 @@
 	export let top = true;
 
 	export let save = false;
+	export let onSourceClick: Function = () => {};
 
 	const headerComponent = (depth: number) => {
 		return 'h' + depth;
@@ -62,7 +63,7 @@
 		<hr />
 	{:else if token.type === 'heading'}
 		<svelte:element this={headerComponent(token.depth)}>
-			<MarkdownInlineTokens id={`${id}-${tokenIdx}-h`} tokens={token.tokens} />
+			<MarkdownInlineTokens id={`${id}-${tokenIdx}-h`} tokens={token.tokens} {onSourceClick} />
 		</svelte:element>
 	{:else if token.type === 'code'}
 		{#if token.raw.includes('```')}
@@ -108,6 +109,7 @@
 										<MarkdownInlineTokens
 											id={`${id}-${tokenIdx}-header-${headerIdx}`}
 											tokens={header.tokens}
+											{onSourceClick}
 										/>
 									</div>
 								</th>
@@ -126,6 +128,7 @@
 											<MarkdownInlineTokens
 												id={`${id}-${tokenIdx}-row-${rowIdx}-${cellIdx}`}
 												tokens={cell.tokens}
+												{onSourceClick}
 											/>
 										</div>
 									</td>
@@ -205,19 +208,27 @@
 		></iframe>
 	{:else if token.type === 'paragraph'}
 		<p>
-			<MarkdownInlineTokens id={`${id}-${tokenIdx}-p`} tokens={token.tokens ?? []} />
+			<MarkdownInlineTokens
+				id={`${id}-${tokenIdx}-p`}
+				tokens={token.tokens ?? []}
+				{onSourceClick}
+			/>
 		</p>
 	{:else if token.type === 'text'}
 		{#if top}
 			<p>
 				{#if token.tokens}
-					<MarkdownInlineTokens id={`${id}-${tokenIdx}-t`} tokens={token.tokens} />
+					<MarkdownInlineTokens id={`${id}-${tokenIdx}-t`} tokens={token.tokens} {onSourceClick} />
 				{:else}
 					{unescapeHtml(token.text)}
 				{/if}
 			</p>
 		{:else if token.tokens}
-			<MarkdownInlineTokens id={`${id}-${tokenIdx}-p`} tokens={token.tokens ?? []} />
+			<MarkdownInlineTokens
+				id={`${id}-${tokenIdx}-p`}
+				tokens={token.tokens ?? []}
+				{onSourceClick}
+			/>
 		{:else}
 			{unescapeHtml(token.text)}
 		{/if}

+ 23 - 0
src/lib/components/chat/Messages/Markdown/Source.svelte

@@ -0,0 +1,23 @@
+<script lang="ts">
+	export let token;
+	export let onClick: Function = () => {};
+
+	let id = '';
+	function extractDataAttribute(input) {
+		// Use a regular expression to extract the value of the `data` attribute
+		const match = input.match(/data="([^"]*)"/);
+		// Check if a match was found and return the first captured group
+		return match ? match[1] : null;
+	}
+
+	$: id = extractDataAttribute(token.text);
+</script>
+
+<button
+	class="text-xs font-medium px-1.5 py-0.5 dark:bg-white/5 dark:hover:bg-white/10 bg-black/5 hover:bg-black/10 transition rounded-lg"
+	on:click={() => {
+		onClick(id);
+	}}
+>
+	{id}
+</button>

+ 9 - 0
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -621,9 +621,18 @@
 									<ContentRenderer
 										id={message.id}
 										content={message.content}
+										citations={message.citations}
 										floatingButtons={message?.done}
 										save={!readOnly}
 										{model}
+										onSourceClick={(e) => {
+											console.log(e);
+											const sourceButton = document.getElementById(`source-${e}`);
+
+											if (sourceButton) {
+												sourceButton.click();
+											}
+										}}
 										on:update={(e) => {
 											const { raw, oldContent, newContent } = e.detail;
 

+ 1 - 5
src/lib/components/chat/Messages/UserMessage.svelte

@@ -5,11 +5,7 @@
 
 	import { models, settings } from '$lib/stores';
 	import { user as _user } from '$lib/stores';
-	import {
-		copyToClipboard as _copyToClipboard,
-		processResponseContent,
-		replaceTokens
-	} from '$lib/utils';
+	import { copyToClipboard as _copyToClipboard } from '$lib/utils';
 
 	import Name from './Name.svelte';
 	import ProfileImage from './ProfileImage.svelte';

+ 14 - 1
src/lib/utils/index.ts

@@ -8,12 +8,13 @@ import { TTS_RESPONSE_SPLIT } from '$lib/types';
 // Helper functions
 //////////////////////////
 
-export const replaceTokens = (content, char, user) => {
+export const replaceTokens = (content, sourceIds, char, user) => {
 	const charToken = /{{char}}/gi;
 	const userToken = /{{user}}/gi;
 	const videoIdToken = /{{VIDEO_FILE_ID_([a-f0-9-]+)}}/gi; // Regex to capture the video ID
 	const htmlIdToken = /{{HTML_FILE_ID_([a-f0-9-]+)}}/gi; // Regex to capture the HTML ID
 
+
 	// Replace {{char}} if char is provided
 	if (char !== undefined && char !== null) {
 		content = content.replace(charToken, char);
@@ -36,6 +37,18 @@ export const replaceTokens = (content, char, user) => {
 		return `<iframe src="${htmlUrl}" width="100%" frameborder="0" onload="this.style.height=(this.contentWindow.document.body.scrollHeight+20)+'px';"></iframe>`;
 	});
 
+
+	// Remove sourceIds from the content and replace them with <source_id>...</source_id>
+    if (Array.isArray(sourceIds)) {
+        sourceIds.forEach((sourceId) => {
+            // Create a token based on the exact `[sourceId]` string
+            const sourceToken = `\\[${sourceId}\\]`; // Escape special characters for RegExp
+            const sourceRegex = new RegExp(sourceToken, 'g'); // Match all occurrences of [sourceId]
+
+            content = content.replace(sourceRegex, `<source_id data="${sourceId}" />`);
+        });
+    }
+
 	return content;
 };