瀏覽代碼

feat: memory integration

Timothy J. Baek 11 月之前
父節點
當前提交
febab58821

+ 1 - 1
backend/apps/web/routers/memories.py

@@ -71,7 +71,7 @@ class QueryMemoryForm(BaseModel):
     content: str
 
 
-@router.post("/query", response_model=Optional[MemoryModel])
+@router.post("/query")
 async def query_memory(
     request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
 ):

+ 1 - 1
src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte

@@ -26,8 +26,8 @@
 		if (res) {
 			console.log(res);
 			toast.success('Memory added successfully');
+			content = '';
 			show = false;
-
 			dispatch('save');
 		}
 

+ 32 - 4
src/routes/(app)/+page.svelte

@@ -41,6 +41,7 @@
 	import { LITELLM_API_BASE_URL, OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL } from '$lib/constants';
 	import { WEBUI_BASE_URL } from '$lib/constants';
 	import { createOpenAITextStream } from '$lib/apis/streaming';
+	import { queryMemory } from '$lib/apis/memories';
 
 	const i18n = getContext('i18n');
 
@@ -254,6 +255,26 @@
 	const sendPrompt = async (prompt, parentId, modelId = null) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
+		let userContext = null;
+
+		if ($settings?.memory ?? false) {
+			const res = await queryMemory(localStorage.token, prompt).catch((error) => {
+				toast.error(error);
+				return null;
+			});
+
+			if (res) {
+				userContext = res.documents.reduce((acc, doc, index) => {
+					const createdAtTimestamp = res.metadatas[index][0].created_at;
+					const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0];
+					acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
+					return acc;
+				}, []);
+
+				console.log(userContext);
+			}
+		}
+
 		await Promise.all(
 			(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
 				async (modelId) => {
@@ -270,6 +291,7 @@
 							role: 'assistant',
 							content: '',
 							model: model.id,
+							userContext: userContext,
 							timestamp: Math.floor(Date.now() / 1000) // Unix epoch
 						};
 
@@ -311,10 +333,13 @@
 		scrollToBottom();
 
 		const messagesBody = [
-			$settings.system
+			$settings.system || responseMessage?.userContext
 				? {
 						role: 'system',
-						content: $settings.system
+						content:
+							$settings.system +
+							(responseMessage?.userContext
+								? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
+								: '')
 				  }
 				: undefined,
 			...messages
@@ -567,10 +592,13 @@
 					model: model.id,
 					stream: true,
 					messages: [
-						$settings.system
+						$settings.system || responseMessage?.userContext
 							? {
 									role: 'system',
-									content: $settings.system
+									content:
+										$settings.system +
+										(responseMessage?.userContext
+											? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
+											: '')
 							  }
 							: undefined,
 						...messages

+ 32 - 4
src/routes/(app)/c/[id]/+page.svelte

@@ -43,6 +43,7 @@
 		WEBUI_BASE_URL
 	} from '$lib/constants';
 	import { createOpenAITextStream } from '$lib/apis/streaming';
+	import { queryMemory } from '$lib/apis/memories';
 
 	const i18n = getContext('i18n');
 
@@ -260,6 +261,26 @@
 	const sendPrompt = async (prompt, parentId, modelId = null) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
+		let userContext = null;
+
+		if ($settings?.memory ?? false) {
+			const res = await queryMemory(localStorage.token, prompt).catch((error) => {
+				toast.error(error);
+				return null;
+			});
+
+			if (res) {
+				userContext = res.documents.reduce((acc, doc, index) => {
+					const createdAtTimestamp = res.metadatas[index][0].created_at;
+					const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0];
+					acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
+					return acc;
+				}, []);
+
+				console.log(userContext);
+			}
+		}
+
 		await Promise.all(
 			(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
 				async (modelId) => {
@@ -317,10 +338,13 @@
 		scrollToBottom();
 
 		const messagesBody = [
-			$settings.system
+			$settings.system || responseMessage?.userContext
 				? {
 						role: 'system',
-						content: $settings.system
+						content:
+							$settings.system +
+							(responseMessage?.userContext
+								? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
+								: '')
 				  }
 				: undefined,
 			...messages
@@ -573,10 +597,13 @@
 					model: model.id,
 					stream: true,
 					messages: [
-						$settings.system
+						$settings.system || responseMessage?.userContext
 							? {
 									role: 'system',
-									content: $settings.system
+									content:
+										$settings.system +
+										(responseMessage?.userContext
+											? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
+											: '')
 							  }
 							: undefined,
 						...messages
@@ -705,6 +732,7 @@
 		} catch (error) {
 			await handleOpenAIError(error, null, model, responseMessage);
 		}
+		messages = messages;
 
 		stopResponseFlag = false;
 		await tick();