Browse Source

feat: collection rag integration

Timothy J. Baek 1 year ago
parent
commit
683650ec00

+ 6 - 6
backend/apps/rag/main.py

@@ -97,15 +97,15 @@ async def get_status():
     return {"status": True}
 
 
-class QueryCollectionForm(BaseModel):
+class QueryDocForm(BaseModel):
     collection_name: str
     query: str
     k: Optional[int] = 4
 
 
-@app.post("/query/collection")
-def query_collection(
-    form_data: QueryCollectionForm,
+@app.post("/query/doc")
+def query_doc(
+    form_data: QueryDocForm,
     user=Depends(get_current_user),
 ):
     try:
@@ -173,8 +173,8 @@ def merge_and_sort_query_results(query_results, k):
     return merged_query_results
 
 
-@app.post("/query/collections")
-def query_collections(
+@app.post("/query/collection")
+def query_collection(
     form_data: QueryCollectionsForm,
     user=Depends(get_current_user),
 ):

+ 39 - 2
src/lib/apis/rag/index.ts

@@ -64,7 +64,7 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string
 	return res;
 };
 
-export const queryCollection = async (
+export const queryDoc = async (
 	token: string,
 	collection_name: string,
 	query: string,
@@ -72,7 +72,7 @@ export const queryCollection = async (
 ) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
@@ -101,6 +101,43 @@ export const queryCollection = async (
 	return res;
 };
 
+export const queryCollection = async (
+	token: string,
+	collection_names: string,
+	query: string,
+	k: number
+) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			collection_names: collection_names,
+			query: query,
+			k: k
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const resetVectorDB = async (token: string) => {
 	let error = null;
 

+ 29 - 0
src/lib/components/chat/Messages/UserMessage.svelte

@@ -117,6 +117,35 @@
 										<div class=" text-gray-500 text-sm">Document</div>
 									</div>
 								</button>
+							{:else if file.type === 'collection'}
+								<button
+									class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 dark:bg-gray-600 rounded-xl border border-gray-200 dark:border-none text-left"
+									type="button"
+								>
+									<div class="p-2.5 bg-red-400 text-white rounded-lg">
+										<svg
+											xmlns="http://www.w3.org/2000/svg"
+											viewBox="0 0 24 24"
+											fill="currentColor"
+											class="w-6 h-6"
+										>
+											<path
+												d="M7.5 3.375c0-1.036.84-1.875 1.875-1.875h.375a3.75 3.75 0 0 1 3.75 3.75v1.875C13.5 8.161 14.34 9 15.375 9h1.875A3.75 3.75 0 0 1 21 12.75v3.375C21 17.16 20.16 18 19.125 18h-9.75A1.875 1.875 0 0 1 7.5 16.125V3.375Z"
+											/>
+											<path
+												d="M15 5.25a5.23 5.23 0 0 0-1.279-3.434 9.768 9.768 0 0 1 6.963 6.963A5.23 5.23 0 0 0 17.25 7.5h-1.875A.375.375 0 0 1 15 7.125V5.25ZM4.875 6H6v10.125A3.375 3.375 0 0 0 9.375 19.5H16.5v1.125c0 1.035-.84 1.875-1.875 1.875h-9.75A1.875 1.875 0 0 1 3 20.625V7.875C3 6.839 3.84 6 4.875 6Z"
+											/>
+										</svg>
+									</div>
+
+									<div class="flex flex-col justify-center -space-y-0.5">
+										<div class=" dark:text-gray-100 text-sm font-medium line-clamp-1">
+											#{file.name}
+										</div>
+
+										<div class=" text-gray-500 text-sm">Collection</div>
+									</div>
+								</button>
 							{/if}
 						</div>
 					{/each}

+ 19 - 8
src/routes/(app)/+page.svelte

@@ -28,7 +28,7 @@
 		getTagsById,
 		updateChatById
 	} from '$lib/apis/chats';
-	import { queryCollection } from '$lib/apis/rag';
+	import { queryCollection, queryDoc } from '$lib/apis/rag';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
@@ -224,7 +224,9 @@
 
 		const docs = messages
 			.filter((message) => message?.files ?? null)
-			.map((message) => message.files.filter((item) => item.type === 'doc'))
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
 			.flat(1);
 
 		console.log(docs);
@@ -234,12 +236,21 @@
 
 			let relevantContexts = await Promise.all(
 				docs.map(async (doc) => {
-					return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch(
-						(error) => {
-							console.log(error);
-							return null;
-						}
-					);
+					if (doc.type === 'collection') {
+						return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch(
+							(error) => {
+								console.log(error);
+								return null;
+							}
+						);
+					} else {
+						return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch(
+							(error) => {
+								console.log(error);
+								return null;
+							}
+						);
+					}
 				})
 			);
 			relevantContexts = relevantContexts.filter((context) => context);

+ 19 - 8
src/routes/(app)/c/[id]/+page.svelte

@@ -29,7 +29,7 @@
 		getTagsById,
 		updateChatById
 	} from '$lib/apis/chats';
-	import { queryCollection } from '$lib/apis/rag';
+	import { queryCollection, queryDoc } from '$lib/apis/rag';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
@@ -238,7 +238,9 @@
 
 		const docs = messages
 			.filter((message) => message?.files ?? null)
-			.map((message) => message.files.filter((item) => item.type === 'doc'))
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
 			.flat(1);
 
 		console.log(docs);
@@ -248,12 +250,21 @@
 
 			let relevantContexts = await Promise.all(
 				docs.map(async (doc) => {
-					return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch(
-						(error) => {
-							console.log(error);
-							return null;
-						}
-					);
+					if (doc.type === 'collection') {
+						return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch(
+							(error) => {
+								console.log(error);
+								return null;
+							}
+						);
+					} else {
+						return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch(
+							(error) => {
+								console.log(error);
+								return null;
+							}
+						);
+					}
 				})
 			);
 			relevantContexts = relevantContexts.filter((context) => context);