Browse Source

feat: web rag support

Timothy J. Baek 1 year ago
parent
commit
28226a6f97

+ 8 - 3
backend/apps/rag/main.py

@@ -37,7 +37,7 @@ from typing import Optional
 import uuid
 import time
 
-from utils.misc import calculate_sha256
+from utils.misc import calculate_sha256, calculate_sha256_string
 from utils.utils import get_current_user
 from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
 from constants import ERROR_MESSAGES
@@ -124,10 +124,15 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
     try:
         loader = WebBaseLoader(form_data.url)
         data = loader.load()
-        store_data_in_vector_db(data, form_data.collection_name)
+
+        collection_name = form_data.collection_name
+        if collection_name == "":
+            collection_name = calculate_sha256_string(form_data.url)[:63]
+
+        store_data_in_vector_db(data, collection_name)
         return {
             "status": True,
-            "collection_name": form_data.collection_name,
+            "collection_name": collection_name,
             "filename": form_data.url,
         }
     except Exception as e:

+ 10 - 0
backend/utils/misc.py

@@ -24,6 +24,16 @@ def calculate_sha256(file):
     return sha256.hexdigest()
 
 
+def calculate_sha256_string(string):
+    # Create a new SHA-256 hash object
+    sha256_hash = hashlib.sha256()
+    # Update the hash object with the bytes of the input string
+    sha256_hash.update(string.encode("utf-8"))
+    # Get the hexadecimal representation of the hash
+    hashed_string = sha256_hash.hexdigest()
+    return hashed_string
+
+
 def validate_email_format(email: str) -> bool:
     if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
         return False

+ 32 - 1
src/lib/components/chat/MessageInput.svelte

@@ -6,7 +6,7 @@
 
 	import Prompts from './MessageInput/PromptCommands.svelte';
 	import Suggestions from './MessageInput/Suggestions.svelte';
-	import { uploadDocToVectorDB } from '$lib/apis/rag';
+	import { uploadDocToVectorDB, uploadWebToVectorDB } from '$lib/apis/rag';
 	import AddFilesPlaceholder from '../AddFilesPlaceholder.svelte';
 	import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
 	import Documents from './MessageInput/Documents.svelte';
@@ -137,6 +137,33 @@
 		}
 	};
 
+	const uploadWeb = async (url) => {
+		console.log(url);
+
+		const doc = {
+			type: 'doc',
+			name: url,
+			collection_name: '',
+			upload_status: false,
+			error: ''
+		};
+
+		try {
+			files = [...files, doc];
+			const res = await uploadWebToVectorDB(localStorage.token, '', url);
+
+			if (res) {
+				doc.upload_status = true;
+				doc.collection_name = res.collection_name;
+				files = files;
+			}
+		} catch (e) {
+			// Remove the failed doc from the files array
+			files = files.filter((f) => f.name !== url);
+			toast.error(e);
+		}
+	};
+
 	onMount(() => {
 		const dropZone = document.querySelector('body');
 
@@ -258,6 +285,10 @@
 					<Documents
 						bind:this={documentsElement}
 						bind:prompt
+						on:url={(e) => {
+							console.log(e);
+							uploadWeb(e.detail);
+						}}
 						on:select={(e) => {
 							console.log(e);
 							files = [

+ 33 - 2
src/lib/components/chat/MessageInput/Documents.svelte

@@ -2,7 +2,7 @@
 	import { createEventDispatcher } from 'svelte';
 
 	import { documents } from '$lib/stores';
-	import { removeFirstHashWord } from '$lib/utils';
+	import { removeFirstHashWord, isValidHttpUrl } from '$lib/utils';
 	import { tick } from 'svelte';
 
 	export let prompt = '';
@@ -37,9 +37,20 @@
 		chatInputElement?.focus();
 		await tick();
 	};
+
+	const confirmSelectWeb = async (url) => {
+		dispatch('url', url);
+
+		prompt = removeFirstHashWord(prompt);
+		const chatInputElement = document.getElementById('chat-textarea');
+
+		await tick();
+		chatInputElement?.focus();
+		await tick();
+	};
 </script>
 
-{#if filteredDocs.length > 0}
+{#if filteredDocs.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
 	<div class="md:px-2 mb-3 text-left w-full">
 		<div class="flex w-full rounded-lg border border-gray-100 dark:border-gray-700">
 			<div class=" bg-gray-100 dark:bg-gray-700 w-10 rounded-l-lg text-center">
@@ -55,6 +66,7 @@
 								: ''}"
 							type="button"
 							on:click={() => {
+								console.log(doc);
 								confirmSelect(doc);
 							}}
 							on:mousemove={() => {
@@ -71,6 +83,25 @@
 							</div>
 						</button>
 					{/each}
+
+					{#if prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
+						<button
+							class="px-3 py-1.5 rounded-lg w-full text-left bg-gray-100 selected-command-option-button"
+							type="button"
+							on:click={() => {
+								const url = prompt.split(' ')?.at(0)?.substring(1);
+								if (isValidHttpUrl(url)) {
+									confirmSelectWeb(url);
+								}
+							}}
+						>
+							<div class=" font-medium text-black line-clamp-1">
+								{prompt.split(' ')?.at(0)?.substring(1)}
+							</div>
+
+							<div class=" text-xs text-gray-600 line-clamp-1">Web</div>
+						</button>
+					{/if}
 				</div>
 			</div>
 		</div>

+ 52 - 31
src/lib/utils/index.ts

@@ -212,8 +212,12 @@ const convertOpenAIMessages = (convo) => {
 		const message = mapping[message_id];
 		currentId = message_id;
 		try {
-				if (messages.length == 0 && (message['message'] == null || 
-				(message['message']['content']['parts']?.[0] == '' && message['message']['content']['text'] == null))) {
+			if (
+				messages.length == 0 &&
+				(message['message'] == null ||
+					(message['message']['content']['parts']?.[0] == '' &&
+						message['message']['content']['text'] == null))
+			) {
 				// Skip chat messages with no content
 				continue;
 			} else {
@@ -222,7 +226,10 @@ const convertOpenAIMessages = (convo) => {
 					parentId: lastId,
 					childrenIds: message['children'] || [],
 					role: message['message']?.['author']?.['role'] !== 'user' ? 'assistant' : 'user',
-					content: message['message']?.['content']?.['parts']?.[0] ||  message['message']?.['content']?.['text'] || '',
+					content:
+						message['message']?.['content']?.['parts']?.[0] ||
+						message['message']?.['content']?.['text'] ||
+						'',
 					model: 'gpt-3.5-turbo',
 					done: true,
 					context: null
@@ -231,7 +238,7 @@ const convertOpenAIMessages = (convo) => {
 				lastId = currentId;
 			}
 		} catch (error) {
-			console.log("Error with", message, "\nError:", error);
+			console.log('Error with', message, '\nError:', error);
 		}
 	}
 
@@ -256,31 +263,31 @@ const validateChat = (chat) => {
 	// Because ChatGPT sometimes has features we can't use like DALL-E or migh have corrupted messages, need to validate
 	const messages = chat.messages;
 
-    // Check if messages array is empty
-    if (messages.length === 0) {
-        return false;
-    }
-
-    // Last message's children should be an empty array
-    const lastMessage = messages[messages.length - 1];
-    if (lastMessage.childrenIds.length !== 0) {
-        return false;
-    }
-
-    // First message's parent should be null
-    const firstMessage = messages[0];
-    if (firstMessage.parentId !== null) {
-        return false;
-    }
-
-    // Every message's content should be a string
-    for (let message of messages) {
-        if (typeof message.content !== 'string') {
-            return false;
-        }
-    }
-
-    return true;
+	// Check if messages array is empty
+	if (messages.length === 0) {
+		return false;
+	}
+
+	// Last message's children should be an empty array
+	const lastMessage = messages[messages.length - 1];
+	if (lastMessage.childrenIds.length !== 0) {
+		return false;
+	}
+
+	// First message's parent should be null
+	const firstMessage = messages[0];
+	if (firstMessage.parentId !== null) {
+		return false;
+	}
+
+	// Every message's content should be a string
+	for (let message of messages) {
+		if (typeof message.content !== 'string') {
+			return false;
+		}
+	}
+
+	return true;
 };
 
 export const convertOpenAIChats = (_chats) => {
@@ -298,8 +305,22 @@ export const convertOpenAIChats = (_chats) => {
 				chat: chat,
 				timestamp: convo['timestamp']
 			});
-		} else { failed ++}
+		} else {
+			failed++;
+		}
 	}
-	console.log(failed, "Conversations could not be imported");
+	console.log(failed, 'Conversations could not be imported');
 	return chats;
 };
+
+export const isValidHttpUrl = (string) => {
+	let url;
+
+	try {
+		url = new URL(string);
+	} catch (_) {
+		return false;
+	}
+
+	return url.protocol === 'http:' || url.protocol === 'https:';
+};