
feat: frontend file upload support

Timothy J. Baek · 1 year ago
commit fef4725d56

backend/apps/rag/main.py (+2, -2)

@@ -91,7 +91,7 @@ def store_web(form_data: StoreWebForm):
         loader = WebBaseLoader(form_data.url)
         data = loader.load()
         store_data_in_vector_db(data, form_data.collection_name)
-        return {"status": True}
+        return {"status": True, "collection_name": form_data.collection_name}
     except Exception as e:
         print(e)
         raise HTTPException(
@@ -129,7 +129,7 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
 
         data = loader.load()
         store_data_in_vector_db(data, collection_name)
-        return {"status": True}
+        return {"status": True, "collection_name": collection_name}
     except Exception as e:
         print(e)
         raise HTTPException(
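
Both handlers now echo the collection name back to the caller, giving the frontend a stable key for the stored document. On the wire the response is roughly the following shape (the TypeScript rendering is illustrative, not part of the commit):

    // Illustrative shape of the JSON now returned by store_web and store_doc.
    interface StoreResponse {
        status: boolean;
        collection_name: string;
    }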

src/lib/apis/rag/index.ts (+0, -3)

@@ -11,7 +11,6 @@ export const uploadDocToVectorDB = async (token: string, collection_name: string
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
-			'Content-Type': 'application/json',
 			authorization: `Bearer ${token}`
 		},
 		body: data
@@ -85,7 +84,6 @@ export const queryVectorDB = async (
 			method: 'GET',
 			headers: {
 				Accept: 'application/json',
-				'Content-Type': 'application/json',
 				authorization: `Bearer ${token}`
 			}
 		}
@@ -96,7 +94,6 @@ export const queryVectorDB = async (
 		})
 		.catch((err) => {
 			error = err.detail;
-			console.log(err);
 			return null;
 		});
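
Dropping the hard-coded 'Content-Type': 'application/json' matters for the upload path: the request body there is a FormData, and the browser has to be left to set multipart/form-data with its own boundary. A minimal sketch of the resulting request, assuming the module's RAG_API_BASE_URL constant and its /doc route (both outside the shown hunks):

    // Sketch only: with no explicit Content-Type, fetch lets the browser emit
    // "Content-Type: multipart/form-data; boundary=..." for the FormData body.
    export const uploadDocSketch = async (token: string, collectionName: string, file: File) => {
        const data = new FormData();
        data.append('collection_name', collectionName); // matches the backend Form(...) field
        data.append('file', file);

        const res = await fetch(`${RAG_API_BASE_URL}/doc`, {
            method: 'POST',
            headers: {
                Accept: 'application/json',
                authorization: `Bearer ${token}`
            },
            body: data
        });
        return res.json();
    };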
 

src/lib/components/chat/MessageInput.svelte (+63, -14)

@@ -2,10 +2,11 @@
 	import toast from 'svelte-french-toast';
 	import { onMount, tick } from 'svelte';
 	import { settings } from '$lib/stores';
-	import { findWordIndices } from '$lib/utils';
+	import { calculateSHA256, findWordIndices } from '$lib/utils';
 
 	import Prompts from './MessageInput/PromptCommands.svelte';
 	import Suggestions from './MessageInput/Suggestions.svelte';
+	import { uploadDocToVectorDB } from '$lib/apis/rag';
 
 	export let submitPrompt: Function;
 	export let stopResponse: Function;
@@ -98,7 +99,7 @@
 			dragged = true;
 		});
 
-		dropZone.addEventListener('drop', (e) => {
+		dropZone.addEventListener('drop', async (e) => {
 			e.preventDefault();
 			console.log(e);
 
@@ -115,14 +116,30 @@
 					];
 				};
 
-				if (
-					e.dataTransfer?.files &&
-					e.dataTransfer?.files.length > 0 &&
-					['image/gif', 'image/jpeg', 'image/png'].includes(e.dataTransfer?.files[0]['type'])
-				) {
-					reader.readAsDataURL(e.dataTransfer?.files[0]);
+				if (e.dataTransfer?.files && e.dataTransfer?.files.length > 0) {
+					const file = e.dataTransfer?.files[0];
+					if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
+						reader.readAsDataURL(file);
+					} else if (['application/pdf', 'text/plain'].includes(file['type'])) {
+						console.log(file);
+						const hash = await calculateSHA256(file);
+						// const res = uploadDocToVectorDB(localStorage.token, hash,file);
+
+						if (true) {
+							files = [
+								...files,
+								{
+									type: 'doc',
+									name: file.name,
+									collection_name: hash
+								}
+							];
+						}
+					} else {
+						toast.error(`Unsupported File Type '${file['type']}'.`);
+					}
 				} else {
-					toast.error(`Unsupported File Type '${e.dataTransfer?.files[0]['type']}'.`);
+					toast.error(`File not found.`);
 				}
 			}
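
The SHA-256 digest of the dropped file doubles as the collection name, so dropping the same document twice should land in the same vector collection; the actual upload call is still commented out in this commit, behind a placeholder if (true). A hedged sketch of how that wiring might look once enabled (the error handling is an assumption, not part of the commit):

    // Sketch only: perform the upload and key the attachment by the confirmed collection.
    const hash = await calculateSHA256(file); // e.g. "sha256:9f86d081..."
    const res = await uploadDocToVectorDB(localStorage.token, hash, file).catch(() => null);

    if (res) {
        files = [
            ...files,
            { type: 'doc', name: file.name, collection_name: res.collection_name ?? hash }
        ];
    } else {
        toast.error(`Failed to upload '${file.name}'.`);
    }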
 
@@ -145,11 +162,11 @@
 		<div class="absolute rounded-xl w-full h-full backdrop-blur bg-gray-800/40 flex justify-center">
 			<div class="m-auto pt-64 flex flex-col justify-center">
 				<div class="max-w-md">
-					<div class="  text-center text-6xl mb-3">🏞️</div>
-					<div class="text-center dark:text-white text-2xl font-semibold z-50">Add Images</div>
+					<div class="  text-center text-6xl mb-3">🗂️</div>
+					<div class="text-center dark:text-white text-2xl font-semibold z-50">Add Files</div>
 
 					<div class=" mt-2 text-center text-sm dark:text-gray-200 w-full">
-						Drop any images here to add to the conversation
+						Drop any files/images here to add to the conversation
 					</div>
 				</div>
 			</div>
@@ -237,10 +254,42 @@
 					}}
 				>
 					{#if files.length > 0}
-						<div class="ml-2 mt-2 mb-1 flex space-x-2">
+						<div class="mx-2 mt-2 mb-1 flex flex-wrap gap-2">
 							{#each files as file, fileIdx}
 								<div class=" relative group">
-									<img src={file.url} alt="input" class=" h-16 w-16 rounded-xl object-cover" />
+									{#if file.type === 'image'}
+										<img src={file.url} alt="input" class=" h-16 w-16 rounded-xl object-cover" />
+									{:else if file.type === 'doc'}
+										<div
+											class="h-16 w-[15rem] flex items-center space-x-3 px-2 bg-gray-600 rounded-xl"
+										>
+											<div class="p-2.5 bg-red-400 rounded-lg">
+												<svg
+													xmlns="http://www.w3.org/2000/svg"
+													viewBox="0 0 24 24"
+													fill="currentColor"
+													class="w-6 h-6"
+												>
+													<path
+														fill-rule="evenodd"
+														d="M5.625 1.5c-1.036 0-1.875.84-1.875 1.875v17.25c0 1.035.84 1.875 1.875 1.875h12.75c1.035 0 1.875-.84 1.875-1.875V12.75A3.75 3.75 0 0 0 16.5 9h-1.875a1.875 1.875 0 0 1-1.875-1.875V5.25A3.75 3.75 0 0 0 9 1.5H5.625ZM7.5 15a.75.75 0 0 1 .75-.75h7.5a.75.75 0 0 1 0 1.5h-7.5A.75.75 0 0 1 7.5 15Zm.75 2.25a.75.75 0 0 0 0 1.5H12a.75.75 0 0 0 0-1.5H8.25Z"
+														clip-rule="evenodd"
+													/>
+													<path
+														d="M12.971 1.816A5.23 5.23 0 0 1 14.25 5.25v1.875c0 .207.168.375.375.375H16.5a5.23 5.23 0 0 1 3.434 1.279 9.768 9.768 0 0 0-6.963-6.963Z"
+													/>
+												</svg>
+											</div>
+
+											<div class="flex flex-col justify-center -space-y-0.5">
+												<div class=" text-gray-100 text-sm line-clamp-1">
+													{file.name}
+												</div>
+
+												<div class=" text-gray-500 text-sm">Document</div>
+											</div>
+										</div>
+									{/if}
 
 									<div class=" absolute -top-1 -right-1">
 										<button

src/lib/utils/index.ts (+35, -0)

@@ -127,3 +127,38 @@ export const findWordIndices = (text) => {
 
 	return matches;
 };
+
+export const calculateSHA256 = async (file) => {
+	console.log(file);
+	// Create a FileReader to read the file asynchronously
+	const reader = new FileReader();
+
+	// Define a promise to handle the file reading
+	const readFile = new Promise((resolve, reject) => {
+		reader.onload = () => resolve(reader.result);
+		reader.onerror = reject;
+	});
+
+	// Read the file as an ArrayBuffer
+	reader.readAsArrayBuffer(file);
+
+	try {
+		// Wait for the FileReader to finish reading the file
+		const buffer = await readFile;
+
+		// Convert the ArrayBuffer to a Uint8Array
+		const uint8Array = new Uint8Array(buffer);
+
+		// Calculate the SHA-256 hash using Web Crypto API
+		const hashBuffer = await crypto.subtle.digest('SHA-256', uint8Array);
+
+		// Convert the hash to a hexadecimal string
+		const hashArray = Array.from(new Uint8Array(hashBuffer));
+		const hashHex = hashArray.map((byte) => byte.toString(16).padStart(2, '0')).join('');
+
+		return `sha256:${hashHex}`;
+	} catch (error) {
+		console.error('Error calculating SHA-256 hash:', error);
+		throw error;
+	}
+};
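
For reference, the same digest can be computed without the FileReader plumbing on browsers that expose File.prototype.arrayBuffer(); an alternative sketch, not what the commit ships:

    // Alternative sketch: File.arrayBuffer() feeds the Web Crypto API directly.
    export const calculateSHA256FromBuffer = async (file: File): Promise<string> => {
        const hashBuffer = await crypto.subtle.digest('SHA-256', await file.arrayBuffer());
        return (
            'sha256:' +
            Array.from(new Uint8Array(hashBuffer))
                .map((byte) => byte.toString(16).padStart(2, '0'))
                .join('')
        );
    };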

src/lib/utils/rag/index.ts (+20, -0)

@@ -0,0 +1,20 @@
+export const RAGTemplate = (context: string, query: string) => {
+	let template = `Use the following context as your learned knowledge, inside <context></context> XML tags.
+	<context>
+	  [context]
+	</context>
+	
+	When answer to user:
+	- If you don't know, just say that you don't know.
+	- If you don't know when you are not sure, ask for clarification.
+	Avoid mentioning that you obtained the information from the context.
+	And answer according to the language of the user's question.
+			
+	Given the context information, answer the query.
+	Query: [query]`;
+
+	template = template.replace(/\[context\]/g, context);
+	template = template.replace(/\[query\]/g, query);
+
+	return template;
+};
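
A small usage sketch (the context and query values are made up):

    // Sketch: retrieved passages replace [context], the user's question replaces [query].
    const context = 'Dropped PDF and plain-text files are chunked and stored in the vector DB.';
    const query = 'Which file types can I drop into the chat?';
    const prompt = RAGTemplate(context, query);
    // `prompt` now carries the passages inside <context></context> tags,
    // followed by "Query: Which file types can I drop into the chat?".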

src/routes/(app)/+page.svelte (+103, -73)

@@ -7,16 +7,18 @@
 	import { page } from '$app/stores';
 
 	import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
+	import { copyToClipboard, splitStream } from '$lib/utils';
 
 	import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
-	import { copyToClipboard, splitStream } from '$lib/utils';
+	import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
+	import { queryVectorDB } from '$lib/apis/rag';
+	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
 	import Messages from '$lib/components/chat/Messages.svelte';
 	import ModelSelector from '$lib/components/chat/ModelSelector.svelte';
 	import Navbar from '$lib/components/layout/Navbar.svelte';
-	import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
-	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
+	import { RAGTemplate } from '$lib/utils/rag';
 
 	let stopResponseFlag = false;
 	let autoScroll = true;
@@ -113,8 +115,103 @@
 	// Ollama functions
 	//////////////////////////
 
+	const submitPrompt = async (userPrompt) => {
+		console.log('submitPrompt', $chatId);
+
+		if (selectedModels.includes('')) {
+			toast.error('Model not selected');
+		} else if (messages.length != 0 && messages.at(-1).done != true) {
+			// Response not done
+			console.log('wait');
+		} else {
+			// Reset chat message textarea height
+			document.getElementById('chat-textarea').style.height = '';
+
+			// Create user message
+			let userMessageId = uuidv4();
+			let userMessage = {
+				id: userMessageId,
+				parentId: messages.length !== 0 ? messages.at(-1).id : null,
+				childrenIds: [],
+				role: 'user',
+				content: userPrompt,
+				files: files.length > 0 ? files : undefined
+			};
+
+			// Add message to history and Set currentId to messageId
+			history.messages[userMessageId] = userMessage;
+			history.currentId = userMessageId;
+
+			// Append messageId to childrenIds of parent message
+			if (messages.length !== 0) {
+				history.messages[messages.at(-1).id].childrenIds.push(userMessageId);
+			}
+
+			// Wait until history/message have been updated
+			await tick();
+
+			// Create new chat if only one message in messages
+			if (messages.length == 1) {
+				if ($settings.saveChatHistory ?? true) {
+					chat = await createNewChat(localStorage.token, {
+						id: $chatId,
+						title: 'New Chat',
+						models: selectedModels,
+						system: $settings.system ?? undefined,
+						options: {
+							...($settings.options ?? {})
+						},
+						messages: messages,
+						history: history,
+						timestamp: Date.now()
+					});
+					await chats.set(await getChatList(localStorage.token));
+					await chatId.set(chat.id);
+				} else {
+					await chatId.set('local');
+				}
+				await tick();
+			}
+
+			// Reset chat input textarea
+			prompt = '';
+			files = [];
+
+			// Send prompt
+			await sendPrompt(userPrompt, userMessageId);
+		}
+	};
+
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
+
+		// TODO: update below to include all ancestral files
+		const docs = history.messages[parentId].files.filter((item) => item.type === 'file');
+
+		if (docs.length > 0) {
+			const query = history.messages[parentId].content;
+
+			let relevantContexts = await Promise.all(
+				docs.map(async (doc) => {
+					return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch(
+						(error) => {
+							console.log(error);
+							return null;
+						}
+					);
+				})
+			);
+			relevantContexts = relevantContexts.filter((context) => context);
+
+			const contextString = relevantContexts.reduce((a, context, i, arr) => {
+				return `${a}${context.documents.join(' ')}\n`;
+			}, '');
+
+			history.messages[parentId].raContent = RAGTemplate(contextString, query);
+			history.messages[parentId].contexts = relevantContexts;
+			await tick();
+		}
+
 		await Promise.all(
 			selectedModels.map(async (model) => {
 				console.log(model);
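
Note that MessageInput tags uploaded documents with type: 'doc' while the filter above looks for type === 'file', and that files is undefined on messages without attachments. A hedged sketch of the same retrieval step written defensively (otherwise following the commit's logic):

    // Sketch only: match the 'doc' tag and tolerate messages without files.
    const docs = (history.messages[parentId].files ?? []).filter((item) => item.type === 'doc');

    if (docs.length > 0) {
        const query = history.messages[parentId].content;

        const relevantContexts = (
            await Promise.all(
                docs.map((doc) =>
                    queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch(() => null)
                )
            )
        ).filter((context) => context);

        const contextString = relevantContexts
            .map((context) => context.documents.join(' '))
            .join('\n');

        history.messages[parentId].raContent = RAGTemplate(contextString, query);
        history.messages[parentId].contexts = relevantContexts;
        await tick();
    }
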
@@ -177,7 +274,7 @@
 				.filter((message) => message)
 				.map((message) => ({
 					role: message.role,
-					content: message.content,
+					content: message?.raContent ?? message.content,
 					...(message.files && {
 						images: message.files
 							.filter((file) => file.type === 'image')
@@ -366,7 +463,7 @@
 								content: [
 									{
 										type: 'text',
-										text: message.content
+										text: message?.raContent ?? message.content
 									},
 									...message.files
 										.filter((file) => file.type === 'image')
@@ -378,7 +475,7 @@
 										}))
 								]
 						  }
-						: { content: message.content })
+						: { content: message?.raContent ?? message.content })
 				})),
 			seed: $settings?.options?.seed ?? undefined,
 			stop: $settings?.options?.stop ?? undefined,
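
The Ollama and OpenAI payload builders above apply the same rule: send the retrieval-augmented prompt when one was built, otherwise the raw message.

    // Sketch: the text that actually reaches the model.
    const outgoingText = message?.raContent ?? message.content;
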
@@ -494,73 +591,6 @@
 		}
 	};
 
-	const submitPrompt = async (userPrompt) => {
-		console.log('submitPrompt', $chatId);
-
-		if (selectedModels.includes('')) {
-			toast.error('Model not selected');
-		} else if (messages.length != 0 && messages.at(-1).done != true) {
-			// Response not done
-			console.log('wait');
-		} else {
-			// Reset chat message textarea height
-			document.getElementById('chat-textarea').style.height = '';
-
-			// Create user message
-			let userMessageId = uuidv4();
-			let userMessage = {
-				id: userMessageId,
-				parentId: messages.length !== 0 ? messages.at(-1).id : null,
-				childrenIds: [],
-				role: 'user',
-				content: userPrompt,
-				files: files.length > 0 ? files : undefined
-			};
-
-			// Add message to history and Set currentId to messageId
-			history.messages[userMessageId] = userMessage;
-			history.currentId = userMessageId;
-
-			// Append messageId to childrenIds of parent message
-			if (messages.length !== 0) {
-				history.messages[messages.at(-1).id].childrenIds.push(userMessageId);
-			}
-
-			// Wait until history/message have been updated
-			await tick();
-
-			// Create new chat if only one message in messages
-			if (messages.length == 1) {
-				if ($settings.saveChatHistory ?? true) {
-					chat = await createNewChat(localStorage.token, {
-						id: $chatId,
-						title: 'New Chat',
-						models: selectedModels,
-						system: $settings.system ?? undefined,
-						options: {
-							...($settings.options ?? {})
-						},
-						messages: messages,
-						history: history,
-						timestamp: Date.now()
-					});
-					await chats.set(await getChatList(localStorage.token));
-					await chatId.set(chat.id);
-				} else {
-					await chatId.set('local');
-				}
-				await tick();
-			}
-
-			// Reset chat input textarea
-			prompt = '';
-			files = [];
-
-			// Send prompt
-			await sendPrompt(userPrompt, userMessageId);
-		}
-	};
-
 	const stopResponse = () => {
 		stopResponseFlag = true;
 		console.log('stopResponse');