Explorar el Código

refac: citations -> sources

Timothy Jaeryang Baek hace 5 meses
padre
commit
81386e9b04

+ 22 - 5
backend/open_webui/apps/retrieval/main.py

@@ -902,10 +902,11 @@ def process_file(
                 Document(
                     page_content=form_data.content,
                     metadata={
-                        "name": file.meta.get("name", file.filename),
+                        **file.meta,
+                        "name": file.filename,
                         "created_by": file.user_id,
                         "file_id": file.id,
-                        **file.meta,
+                        "source": file.filename,
                     },
                 )
             ]
@@ -932,10 +933,11 @@ def process_file(
                     Document(
                         page_content=file.data.get("content", ""),
                         metadata={
-                            "name": file.meta.get("name", file.filename),
+                            **file.meta,
+                            "name": file.filename,
                             "created_by": file.user_id,
                             "file_id": file.id,
-                            **file.meta,
+                            "source": file.filename,
                         },
                     )
                 ]
@@ -955,15 +957,30 @@ def process_file(
                 docs = loader.load(
                     file.filename, file.meta.get("content_type"), file_path
                 )
+
+                docs = [
+                    Document(
+                        page_content=doc.page_content,
+                        metadata={
+                            **doc.metadata,
+                            "name": file.filename,
+                            "created_by": file.user_id,
+                            "file_id": file.id,
+                            "source": file.filename,
+                        },
+                    )
+                    for doc in docs
+                ]
             else:
                 docs = [
                     Document(
                         page_content=file.data.get("content", ""),
                         metadata={
+                            **file.meta,
                             "name": file.filename,
                             "created_by": file.user_id,
                             "file_id": file.id,
-                            **file.meta,
+                            "source": file.filename,
                         },
                     )
                 ]

+ 7 - 26
backend/open_webui/apps/retrieval/utils.py

@@ -307,7 +307,7 @@ def get_embedding_function(
         return lambda query: generate_multiple(query, func)
 
 
-def get_rag_context(
+def get_sources_from_files(
     files,
     queries,
     embedding_function,
@@ -387,43 +387,24 @@ def get_rag_context(
                 del file["data"]
             relevant_contexts.append({**context, "file": file})
 
-    contexts = []
-    citations = []
+    sources = []
     for context in relevant_contexts:
         try:
             if "documents" in context:
-                file_names = list(
-                    set(
-                        [
-                            metadata["name"]
-                            for metadata in context["metadatas"][0]
-                            if metadata is not None and "name" in metadata
-                        ]
-                    )
-                )
-                contexts.append(
-                    ((", ".join(file_names) + ":\n\n") if file_names else "")
-                    + "\n\n".join(
-                        [text for text in context["documents"][0] if text is not None]
-                    )
-                )
-
                 if "metadatas" in context:
-                    citation = {
+                    source = {
                         "source": context["file"],
                         "document": context["documents"][0],
                         "metadata": context["metadatas"][0],
                     }
                     if "distances" in context and context["distances"]:
-                        citation["distances"] = context["distances"][0]
-                    citations.append(citation)
+                        source["distances"] = context["distances"][0]
+
+                    sources.append(source)
         except Exception as e:
             log.exception(e)
 
-    print("contexts", contexts)
-    print("citations", citations)
-
-    return contexts, citations
+    return sources
 
 
 def get_model_path(model: str, update_model: bool = False):

+ 1 - 1
backend/open_webui/apps/webui/routers/files.py

@@ -56,7 +56,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
             FileForm(
                 **{
                     "id": id,
-                    "filename": filename,
+                    "filename": name,
                     "path": file_path,
                     "meta": {
                         "name": name,

+ 60 - 46
backend/open_webui/main.py

@@ -49,7 +49,7 @@ from open_webui.apps.openai.main import (
     get_all_models_responses as get_openai_models_responses,
 )
 from open_webui.apps.retrieval.main import app as retrieval_app
-from open_webui.apps.retrieval.utils import get_rag_context, rag_template
+from open_webui.apps.retrieval.utils import get_sources_from_files, rag_template
 from open_webui.apps.socket.main import (
     app as socket_app,
     periodic_usage_pool_cleanup,
@@ -380,8 +380,7 @@ async def chat_completion_tools_handler(
         return body, {}
 
     skip_files = False
-    contexts = []
-    citations = []
+    sources = []
 
     task_model_id = get_task_model_id(
         body["model"],
@@ -465,24 +464,37 @@ async def chat_completion_tools_handler(
 
             print(tools[tool_function_name]["citation"])
 
-            if tools[tool_function_name]["citation"]:
-                citations.append(
-                    {
-                        "source": {
-                            "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
-                        },
-                        "document": [tool_output],
-                        "metadata": [{"source": tool_function_name}],
-                    }
-                )
-            else:
-                citations.append({})
+            if isinstance(tool_output, str):
+                if tools[tool_function_name]["citation"]:
+                    sources.append(
+                        {
+                            "source": {
+                                "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
+                            },
+                            "document": [tool_output],
+                            "metadata": [
+                                {
+                                    "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
+                                }
+                            ],
+                        }
+                    )
+                else:
+                    sources.append(
+                        {
+                            "source": {},
+                            "document": [tool_output],
+                            "metadata": [
+                                {
+                                    "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
+                                }
+                            ],
+                        }
+                    )
 
-            if tools[tool_function_name]["file_handler"]:
-                skip_files = True
+                if tools[tool_function_name]["file_handler"]:
+                    skip_files = True
 
-            if isinstance(tool_output, str):
-                contexts.append(tool_output)
         except Exception as e:
             log.exception(f"Error: {e}")
             content = None
@@ -490,19 +502,18 @@ async def chat_completion_tools_handler(
         log.exception(f"Error: {e}")
         content = None
 
-    log.debug(f"tool_contexts: {contexts} {citations}")
+    log.debug(f"tool_contexts: {sources}")
 
     if skip_files and "files" in body.get("metadata", {}):
         del body["metadata"]["files"]
 
-    return body, {"contexts": contexts, "citations": citations}
+    return body, {"sources": sources}
 
 
 async def chat_completion_files_handler(
     body: dict, user: UserModel
 ) -> tuple[dict, dict[str, list]]:
-    contexts = []
-    citations = []
+    sources = []
 
     try:
         queries_response = await generate_queries(
@@ -530,7 +541,7 @@ async def chat_completion_files_handler(
     print(f"{queries=}")
 
     if files := body.get("metadata", {}).get("files", None):
-        contexts, citations = get_rag_context(
+        sources = get_sources_from_files(
             files=files,
             queries=queries,
             embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
@@ -540,9 +551,8 @@ async def chat_completion_files_handler(
             hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
         )
 
-        log.debug(f"rag_contexts: {contexts}, citations: {citations}")
-
-    return body, {"contexts": contexts, "citations": citations}
+        log.debug(f"rag_contexts:sources: {sources}")
+    return body, {"sources": sources}
 
 
 def is_chat_completion_request(request):
@@ -643,8 +653,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         # Initialize data_items to store additional data to be sent to the client
         # Initialize contexts and citation
         data_items = []
-        contexts = []
-        citations = []
+        sources = []
 
         try:
             body, flags = await chat_completion_filter_functions_handler(
@@ -670,32 +679,34 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             body, flags = await chat_completion_tools_handler(
                 body, user, models, extra_params
             )
-            contexts.extend(flags.get("contexts", []))
-            citations.extend(flags.get("citations", []))
+            sources.extend(flags.get("sources", []))
         except Exception as e:
             log.exception(e)
 
         try:
             body, flags = await chat_completion_files_handler(body, user)
-            contexts.extend(flags.get("contexts", []))
-            citations.extend(flags.get("citations", []))
+            sources.extend(flags.get("sources", []))
         except Exception as e:
             log.exception(e)
 
         # If context is not empty, insert it into the messages
-        if len(contexts) > 0:
+        if len(sources) > 0:
             context_string = ""
-            for context_idx, context in enumerate(contexts):
-                print(context)
-                source_id = citations[context_idx].get("source", {}).get("name", "")
+            for source_idx, source in enumerate(sources):
+                source_id = source.get("source", {}).get("name", "")
 
-                print(f"\n\n\n\n{source_id}\n\n\n\n")
-                if source_id:
-                    context_string += f"<source><source_id>{source_id}</source_id><source_context>{context}</source_context></source>\n"
-                else:
-                    context_string += (
-                        f"<source><source_context>{context}</source_context></source>\n"
-                    )
+                if "document" in source:
+                    for doc_idx, doc_context in enumerate(source["document"]):
+                        metadata = source.get("metadata")
+
+                        if metadata:
+                            doc_source_id = metadata[doc_idx].get("source", source_id)
+
+                        if source_id:
+                            context_string += f"<source><source_id>{doc_source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
+                        else:
+                            # If there is no source_id, then do not include the source_id tag
+                            context_string += f"<source><source_context>{doc_context}</source_context></source>\n"
 
             context_string = context_string.strip()
             prompt = get_last_user_message(body["messages"])
@@ -728,8 +739,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 )
 
         # If there are citations, add them to the data_items
-        if len(citations) > 0:
-            data_items.append({"citations": citations})
+        sources = [
+            source for source in sources if source.get("source", {}).get("name", "")
+        ]
+        if len(sources) > 0:
+            data_items.append({"sources": sources})
 
         modified_body_bytes = json.dumps(body).encode("utf-8")
         # Replace the request body with the modified one

+ 4 - 4
src/lib/apis/streaming/index.ts

@@ -5,7 +5,7 @@ type TextStreamUpdate = {
 	done: boolean;
 	value: string;
 	// eslint-disable-next-line @typescript-eslint/no-explicit-any
-	citations?: any;
+	sources?: any;
 	// eslint-disable-next-line @typescript-eslint/no-explicit-any
 	selectedModelId?: any;
 	error?: any;
@@ -67,8 +67,8 @@ async function* openAIStreamToIterator(
 				break;
 			}
 
-			if (parsedData.citations) {
-				yield { done: false, value: '', citations: parsedData.citations };
+			if (parsedData.sources) {
+				yield { done: false, value: '', sources: parsedData.sources };
 				continue;
 			}
 
@@ -98,7 +98,7 @@ async function* streamLargeDeltasAsRandomChunks(
 			yield textStreamUpdate;
 			return;
 		}
-		if (textStreamUpdate.citations) {
+		if (textStreamUpdate.sources) {
 			yield textStreamUpdate;
 			continue;
 		}

+ 11 - 11
src/lib/components/chat/Chat.svelte

@@ -236,10 +236,10 @@
 					message.code_executions = message.code_executions;
 				} else {
 					// Regular citation.
-					if (message?.citations) {
-						message.citations.push(data);
+					if (message?.sources) {
+						message.sources.push(data);
 					} else {
-						message.citations = [data];
+						message.sources = [data];
 					}
 				}
 			} else if (type === 'message') {
@@ -664,7 +664,7 @@
 				content: m.content,
 				info: m.info ? m.info : undefined,
 				timestamp: m.timestamp,
-				...(m.citations ? { citations: m.citations } : {})
+				...(m.sources ? { sources: m.sources } : {})
 			})),
 			chat_id: chatId,
 			session_id: $socket?.id,
@@ -718,7 +718,7 @@
 				content: m.content,
 				info: m.info ? m.info : undefined,
 				timestamp: m.timestamp,
-				...(m.citations ? { citations: m.citations } : {})
+				...(m.sources ? { sources: m.sources } : {})
 			})),
 			...(event ? { event: event } : {}),
 			chat_id: chatId,
@@ -1278,8 +1278,8 @@
 								console.log(line);
 								let data = JSON.parse(line);
 
-								if ('citations' in data) {
-									responseMessage.citations = data.citations;
+								if ('sources' in data) {
+									responseMessage.sources = data.sources;
 									// Only remove status if it was initially set
 									if (model?.info?.meta?.knowledge ?? false) {
 										responseMessage.statusHistory = responseMessage.statusHistory.filter(
@@ -1632,7 +1632,7 @@
 					const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
 
 					for await (const update of textStream) {
-						const { value, done, citations, selectedModelId, error, usage } = update;
+						const { value, done, sources, selectedModelId, error, usage } = update;
 						if (error) {
 							await handleOpenAIError(error, null, model, responseMessage);
 							break;
@@ -1658,8 +1658,8 @@
 							continue;
 						}
 
-						if (citations) {
-							responseMessage.citations = citations;
+						if (sources) {
+							responseMessage.sources = sources;
 							// Only remove status if it was initially set
 							if (model?.info?.meta?.knowledge ?? false) {
 								responseMessage.statusHistory = responseMessage.statusHistory.filter(
@@ -1938,7 +1938,7 @@
 			if (res && res.ok && res.body) {
 				const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
 				for await (const update of textStream) {
-					const { value, done, citations, error, usage } = update;
+					const { value, done, sources, error, usage } = update;
 					if (error || done) {
 						break;
 					}

+ 28 - 26
src/lib/components/chat/Messages/Citations.svelte

@@ -7,9 +7,9 @@
 
 	const i18n = getContext('i18n');
 
-	export let citations = [];
+	export let sources = [];
 
-	let _citations = [];
+	let citations = [];
 	let showPercentage = false;
 	let showRelevance = true;
 
@@ -17,8 +17,8 @@
 	let selectedCitation: any = null;
 	let isCollapsibleOpen = false;
 
-	function calculateShowRelevance(citations: any[]) {
-		const distances = citations.flatMap((citation) => citation.distances ?? []);
+	function calculateShowRelevance(sources: any[]) {
+		const distances = sources.flatMap((citation) => citation.distances ?? []);
 		const inRange = distances.filter((d) => d !== undefined && d >= -1 && d <= 1).length;
 		const outOfRange = distances.filter((d) => d !== undefined && (d < -1 || d > 1)).length;
 
@@ -36,29 +36,31 @@
 		return true;
 	}
 
-	function shouldShowPercentage(citations: any[]) {
-		const distances = citations.flatMap((citation) => citation.distances ?? []);
+	function shouldShowPercentage(sources: any[]) {
+		const distances = sources.flatMap((citation) => citation.distances ?? []);
 		return distances.every((d) => d !== undefined && d >= -1 && d <= 1);
 	}
 
 	$: {
-		_citations = citations.reduce((acc, citation) => {
-			if (Object.keys(citation).length === 0) {
+		citations = sources.reduce((acc, source) => {
+			if (Object.keys(source).length === 0) {
 				return acc;
 			}
 
-			citation.document.forEach((document, index) => {
-				const metadata = citation.metadata?.[index];
-				const distance = citation.distances?.[index];
+			source.document.forEach((document, index) => {
+				const metadata = source.metadata?.[index];
+				const distance = source.distances?.[index];
+
+				// Within the same citation there could be multiple documents
 				const id = metadata?.source ?? 'N/A';
-				let source = citation?.source;
+				let _source = source?.source;
 
 				if (metadata?.name) {
-					source = { ...source, name: metadata.name };
+					_source = { ..._source, name: metadata.name };
 				}
 
 				if (id.startsWith('http://') || id.startsWith('https://')) {
-					source = { ...source, name: id, url: id };
+					_source = { ..._source, name: id, url: id };
 				}
 
 				const existingSource = acc.find((item) => item.id === id);
@@ -70,7 +72,7 @@
 				} else {
 					acc.push({
 						id: id,
-						source: source,
+						source: _source,
 						document: [document],
 						metadata: metadata ? [metadata] : [],
 						distances: distance !== undefined ? [distance] : undefined
@@ -80,8 +82,8 @@
 			return acc;
 		}, []);
 
-		showRelevance = calculateShowRelevance(_citations);
-		showPercentage = shouldShowPercentage(_citations);
+		showRelevance = calculateShowRelevance(citations);
+		showPercentage = shouldShowPercentage(citations);
 	}
 </script>
 
@@ -92,11 +94,11 @@
 	{showRelevance}
 />
 
-{#if _citations.length > 0}
+{#if citations.length > 0}
 	<div class=" py-0.5 -mx-0.5 w-full flex gap-1 items-center flex-wrap">
-		{#if _citations.length <= 3}
+		{#if citations.length <= 3}
 			<div class="flex text-xs font-medium">
-				{#each _citations as citation, idx}
+				{#each citations as citation, idx}
 					<button
 						id={`source-${citation.source.name}`}
 						class="no-toggle outline-none flex dark:text-gray-300 p-1 bg-white dark:bg-gray-900 rounded-xl max-w-96"
@@ -105,7 +107,7 @@
 							selectedCitation = citation;
 						}}
 					>
-						{#if _citations.every((c) => c.distances !== undefined)}
+						{#if citations.every((c) => c.distances !== undefined)}
 							<div class="bg-gray-50 dark:bg-gray-800 rounded-full size-4">
 								{idx + 1}
 							</div>
@@ -127,7 +129,7 @@
 						<span class="whitespace-nowrap hidden sm:inline">{$i18n.t('References from')}</span>
 						<div class="flex items-center">
 							<div class="flex text-xs font-medium items-center">
-								{#each _citations.slice(0, 2) as citation, idx}
+								{#each citations.slice(0, 2) as citation, idx}
 									<button
 										class="no-toggle outline-none flex dark:text-gray-300 p-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-900 dark:hover:bg-gray-850 transition rounded-xl max-w-96"
 										on:click={() => {
@@ -138,7 +140,7 @@
 											e.stopPropagation();
 										}}
 									>
-										{#if _citations.every((c) => c.distances !== undefined)}
+										{#if citations.every((c) => c.distances !== undefined)}
 											<div class="bg-gray-50 dark:bg-gray-800 rounded-full size-4">
 												{idx + 1}
 											</div>
@@ -152,7 +154,7 @@
 						</div>
 						<div class="flex items-center gap-1 whitespace-nowrap">
 							<span class="hidden sm:inline">{$i18n.t('and')}</span>
-							{_citations.length - 2}
+							{citations.length - 2}
 							<span>{$i18n.t('more')}</span>
 						</div>
 					</div>
@@ -166,7 +168,7 @@
 				</div>
 				<div slot="content">
 					<div class="flex text-xs font-medium">
-						{#each _citations as citation, idx}
+						{#each citations as citation, idx}
 							<button
 								class="no-toggle outline-none flex dark:text-gray-300 p-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-900 dark:hover:bg-gray-850 transition rounded-xl max-w-96"
 								on:click={() => {
@@ -174,7 +176,7 @@
 									selectedCitation = citation;
 								}}
 							>
-								{#if _citations.every((c) => c.distances !== undefined)}
+								{#if citations.every((c) => c.distances !== undefined)}
 									<div class="bg-gray-50 dark:bg-gray-800 rounded-full size-4">
 										{idx + 1}
 									</div>

+ 27 - 2
src/lib/components/chat/Messages/ContentRenderer.svelte

@@ -7,11 +7,12 @@
 	import LightBlub from '$lib/components/icons/LightBlub.svelte';
 	import { chatId, mobile, showArtifacts, showControls, showOverview } from '$lib/stores';
 	import ChatBubble from '$lib/components/icons/ChatBubble.svelte';
+	import { stringify } from 'postcss';
 
 	export let id;
 	export let content;
 	export let model = null;
-	export let citations = null;
+	export let sources = null;
 
 	export let save = false;
 	export let floatingButtons = true;
@@ -131,7 +132,31 @@
 		{content}
 		{model}
 		{save}
-		sourceIds={(citations ?? []).map((c) => c?.source?.name)}
+		sourceIds={(sources ?? []).reduce((acc, s) => {
+			let ids = [];
+			s.document.forEach((document, index) => {
+				const metadata = s.metadata?.[index];
+				const id = metadata?.source ?? 'N/A';
+
+				if (metadata?.name) {
+					ids.push(metadata.name);
+					return ids;
+				}
+
+				if (id.startsWith('http://') || id.startsWith('https://')) {
+					ids.push(id);
+				} else {
+					ids.push(s?.source?.name ?? id);
+				}
+
+				return ids;
+			});
+
+			acc = [...acc, ...ids];
+
+			// remove duplicates
+			return acc.filter((item, index) => acc.indexOf(item) === index);
+		}, [])}
 		{onSourceClick}
 		on:update={(e) => {
 			dispatch('update', e.detail);

+ 4 - 4
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -64,7 +64,7 @@
 		};
 		done: boolean;
 		error?: boolean | { content: string };
-		citations?: string[];
+		sources?: string[];
 		code_executions?: {
 			uuid: string;
 			name: string;
@@ -621,7 +621,7 @@
 									<ContentRenderer
 										id={message.id}
 										content={message.content}
-										citations={message.citations}
+										sources={message.sources}
 										floatingButtons={message?.done}
 										save={!readOnly}
 										{model}
@@ -662,8 +662,8 @@
 									<Error content={message?.error?.content ?? message.content} />
 								{/if}
 
-								{#if message.citations && (model?.info?.meta?.capabilities?.citations ?? true)}
-									<Citations citations={message.citations} />
+								{#if (message?.sources || message?.citations) && (model?.info?.meta?.capabilities?.citations ?? true)}
+									<Citations sources={message?.sources ?? message?.citations} />
 								{/if}
 
 								{#if message.code_executions}