Parcourir la source

Merge remote-tracking branch 'upstream/dev' into feat/backend-web-search

Jun Siang Cheah il y a 11 mois
Parent
commit
224a578e6b

+ 6 - 0
backend/apps/rag/main.py

@@ -28,6 +28,7 @@ from langchain_community.document_loaders import (
     UnstructuredXMLLoader,
     UnstructuredRSTLoader,
     UnstructuredExcelLoader,
+    UnstructuredPowerPointLoader,
     YoutubeLoader,
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -823,6 +824,11 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
         "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
     ] or file_ext in ["xls", "xlsx"]:
         loader = UnstructuredExcelLoader(file_path)
+    elif file_content_type in [
+        "application/vnd.ms-powerpoint",
+        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+    ] or file_ext in ["ppt", "pptx"]:
+        loader = UnstructuredPowerPointLoader(file_path)
     elif file_ext in known_source_ext or (
         file_content_type and file_content_type.find("text/") >= 0
     ):

+ 1 - 0
backend/requirements.txt

@@ -35,6 +35,7 @@ chromadb==0.4.24
 sentence-transformers==2.7.0
 pypdf==4.2.0
 docx2txt==0.8
+python-pptx==0.6.23
 unstructured==0.11.8
 Markdown==3.6
 pypandoc==1.13

+ 1 - 1
src/lib/components/chat/Messages/CodeBlock.svelte

@@ -213,7 +213,7 @@ __builtins__.input = input`);
 			<div class="p-1">{@html lang}</div>
 
 			<div class="flex items-center">
-				{#if ['', 'python'].includes(lang) && (lang === 'python' || checkPythonCode(code))}
+				{#if lang === 'python' || (lang === '' && checkPythonCode(code))}
 					{#if executing}
 						<div class="copy-code-button bg-none border-none p-1 cursor-not-allowed">Running</div>
 					{:else}

+ 40 - 36
src/lib/components/chat/Messages/CompareMessages.svelte

@@ -41,6 +41,44 @@
 		};
 	}, {});
 
+	const showPreviousMessage = (model) => {
+		groupedMessagesIdx[model] = Math.max(0, groupedMessagesIdx[model] - 1);
+		let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
+
+		console.log(messageId);
+		let messageChildrenIds = history.messages[messageId].childrenIds;
+
+		while (messageChildrenIds.length !== 0) {
+			messageId = messageChildrenIds.at(-1);
+			messageChildrenIds = history.messages[messageId].childrenIds;
+		}
+
+		history.currentId = messageId;
+
+		dispatch('change');
+	};
+
+	const showNextMessage = (model) => {
+		groupedMessagesIdx[model] = Math.min(
+			groupedMessages[model].messages.length - 1,
+			groupedMessagesIdx[model] + 1
+		);
+
+		let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
+		console.log(messageId);
+
+		let messageChildrenIds = history.messages[messageId].childrenIds;
+
+		while (messageChildrenIds.length !== 0) {
+			messageId = messageChildrenIds.at(-1);
+			messageChildrenIds = history.messages[messageId].childrenIds;
+		}
+
+		history.currentId = messageId;
+
+		dispatch('change');
+	};
+
 	onMount(async () => {
 		await tick();
 		currentMessageId = messages[messageIdx].id;
@@ -97,42 +135,8 @@
 						isLastMessage={true}
 						{updateChatMessages}
 						{confirmEditResponseMessage}
-						showPreviousMessage={() => {
-							groupedMessagesIdx[model] = Math.max(0, groupedMessagesIdx[model] - 1);
-							let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
-
-							console.log(messageId);
-							let messageChildrenIds = history.messages[messageId].childrenIds;
-
-							while (messageChildrenIds.length !== 0) {
-								messageId = messageChildrenIds.at(-1);
-								messageChildrenIds = history.messages[messageId].childrenIds;
-							}
-
-							history.currentId = messageId;
-
-							dispatch('change');
-						}}
-						showNextMessage={() => {
-							groupedMessagesIdx[model] = Math.min(
-								groupedMessages[model].messages.length - 1,
-								groupedMessagesIdx[model] + 1
-							);
-
-							let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
-							console.log(messageId);
-
-							let messageChildrenIds = history.messages[messageId].childrenIds;
-
-							while (messageChildrenIds.length !== 0) {
-								messageId = messageChildrenIds.at(-1);
-								messageChildrenIds = history.messages[messageId].childrenIds;
-							}
-
-							history.currentId = messageId;
-
-							dispatch('change');
-						}}
+						showPreviousMessage={() => showPreviousMessage(model)}
+						showNextMessage={() => showNextMessage(model)}
 						{rateMessage}
 						{copyToClipboard}
 						{continueGeneration}

+ 2 - 1
src/lib/components/chat/Messages/ProfileImage.svelte

@@ -10,7 +10,8 @@
 		crossorigin="anonymous"
 		src={src.startsWith(WEBUI_BASE_URL) ||
 		src.startsWith('https://www.gravatar.com/avatar/') ||
-		src.startsWith('data:')
+		src.startsWith('data:') ||
+		src.startsWith('/')
 			? src
 			: `/user.png`}
 		class=" w-8 object-cover rounded-full"

+ 3 - 1
src/lib/constants.ts

@@ -86,7 +86,9 @@ export const SUPPORTED_FILE_EXTENSIONS = [
 	'csv',
 	'txt',
 	'xls',
-	'xlsx'
+	'xlsx',
+	'pptx',
+	'ppt'
 ];
 
 // Source: https://kit.svelte.dev/docs/modules#$env-static-public

+ 37 - 29
src/routes/(app)/+page.svelte

@@ -261,28 +261,6 @@
 	const sendPrompt = async (prompt, parentId, modelId = null) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		let userContext = null;
-
-		if ($settings?.memory ?? false) {
-			const res = await queryMemory(localStorage.token, prompt).catch((error) => {
-				toast.error(error);
-				return null;
-			});
-
-			if (res) {
-				if (res.documents[0].length > 0) {
-					userContext = res.documents.reduce((acc, doc, index) => {
-						const createdAtTimestamp = res.metadatas[index][0].created_at;
-						const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0];
-						acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
-						return acc;
-					}, []);
-				}
-
-				console.log(userContext);
-			}
-		}
-
 		await Promise.all(
 			(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
 				async (modelId) => {
@@ -299,7 +277,7 @@
 							role: 'assistant',
 							content: '',
 							model: model.id,
-							userContext: userContext,
+							userContext: null,
 							timestamp: Math.floor(Date.now() / 1000) // Unix epoch
 						};
 
@@ -315,6 +293,34 @@
 							];
 						}
 
+						await tick();
+
+						let userContext = null;
+						if ($settings?.memory ?? false) {
+							if (userContext === null) {
+								const res = await queryMemory(localStorage.token, prompt).catch((error) => {
+									toast.error(error);
+									return null;
+								});
+
+								if (res) {
+									if (res.documents[0].length > 0) {
+										userContext = res.documents.reduce((acc, doc, index) => {
+											const createdAtTimestamp = res.metadatas[index][0].created_at;
+											const createdAtDate = new Date(createdAtTimestamp * 1000)
+												.toISOString()
+												.split('T')[0];
+											acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
+											return acc;
+										}, []);
+									}
+
+									console.log(userContext);
+								}
+							}
+						}
+						responseMessage.userContext = userContext;
+
 						if (useWebSearch) {
 							await runWebSearchForPrompt(model.id, parentId, responseMessageId);
 						}
@@ -383,10 +389,11 @@
 			$settings.system || (responseMessage?.userContext ?? null)
 				? {
 						role: 'system',
-						content:
-							$settings.system + (responseMessage?.userContext ?? null)
-								? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
+						content: `${$settings?.system ?? ''}${
+							responseMessage?.userContext ?? null
+								? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
 								: ''
+						}`
 				  }
 				: undefined,
 			...messages
@@ -642,10 +649,11 @@
 						$settings.system || (responseMessage?.userContext ?? null)
 							? {
 									role: 'system',
-									content:
-										$settings.system + (responseMessage?.userContext ?? null)
-											? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
+									content: `${$settings?.system ?? ''}${
+										responseMessage?.userContext ?? null
+											? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
 											: ''
+									}`
 							  }
 							: undefined,
 						...messages

+ 37 - 29
src/routes/(app)/c/[id]/+page.svelte

@@ -268,28 +268,6 @@
 	const sendPrompt = async (prompt, parentId, modelId = null) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		let userContext = null;
-
-		if ($settings?.memory ?? false) {
-			const res = await queryMemory(localStorage.token, prompt).catch((error) => {
-				toast.error(error);
-				return null;
-			});
-
-			if (res) {
-				if (res.documents[0].length > 0) {
-					userContext = res.documents.reduce((acc, doc, index) => {
-						const createdAtTimestamp = res.metadatas[index][0].created_at;
-						const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0];
-						acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
-						return acc;
-					}, []);
-				}
-
-				console.log(userContext);
-			}
-		}
-
 		await Promise.all(
 			(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
 				async (modelId) => {
@@ -306,7 +284,7 @@
 							role: 'assistant',
 							content: '',
 							model: model.id,
-							userContext: userContext,
+							userContext: null,
 							timestamp: Math.floor(Date.now() / 1000) // Unix epoch
 						};
 
@@ -322,6 +300,34 @@
 							];
 						}
 
+						await tick();
+
+						let userContext = null;
+						if ($settings?.memory ?? false) {
+							if (userContext === null) {
+								const res = await queryMemory(localStorage.token, prompt).catch((error) => {
+									toast.error(error);
+									return null;
+								});
+
+								if (res) {
+									if (res.documents[0].length > 0) {
+										userContext = res.documents.reduce((acc, doc, index) => {
+											const createdAtTimestamp = res.metadatas[index][0].created_at;
+											const createdAtDate = new Date(createdAtTimestamp * 1000)
+												.toISOString()
+												.split('T')[0];
+											acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
+											return acc;
+										}, []);
+									}
+
+									console.log(userContext);
+								}
+							}
+						}
+						responseMessage.userContext = userContext;
+
 						if (useWebSearch) {
 							await runWebSearchForPrompt(model.id, parentId, responseMessageId);
 						}
@@ -390,10 +396,11 @@
 			$settings.system || (responseMessage?.userContext ?? null)
 				? {
 						role: 'system',
-						content:
-							$settings.system + (responseMessage?.userContext ?? null)
-								? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
+						content: `${$settings?.system ?? ''}${
+							responseMessage?.userContext ?? null
+								? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
 								: ''
+						}`
 				  }
 				: undefined,
 			...messages
@@ -649,10 +656,11 @@
 						$settings.system || (responseMessage?.userContext ?? null)
 							? {
 									role: 'system',
-									content:
-										$settings.system + (responseMessage?.userContext ?? null)
-											? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
+									content: `${$settings?.system ?? ''}${
+										responseMessage?.userContext ?? null
+											? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
 											: ''
+									}`
 							  }
 							: undefined,
 						...messages