Browse Source

feat: add rag top k value setting

Timothy J. Baek 1 year ago
parent
commit
47a05a47b4

+ 34 - 14
backend/apps/rag/main.py

@@ -79,6 +79,8 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.TOP_K = 4
+
 app.state.sentence_transformer_ef = (
     embedding_functions.SentenceTransformerEmbeddingFunction(
         model_name=app.state.RAG_EMBEDDING_MODEL,
@@ -210,23 +212,33 @@ async def get_rag_template(user=Depends(get_current_user)):
     }
 
 
-class RAGTemplateForm(BaseModel):
-    template: str
+@app.get("/query/settings")
+async def get_query_settings(user=Depends(get_admin_user)):
+    # Report the query settings currently held on app.state; the caller is
+    # resolved through the get_admin_user dependency.
+    template = app.state.RAG_TEMPLATE
+    top_k = app.state.TOP_K
+    return {"status": True, "template": template, "k": top_k}
 
 
-@app.post("/template/update")
-async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
-    # TODO: check template requirements
-    app.state.RAG_TEMPLATE = (
-        form_data.template if form_data.template != "" else RAG_TEMPLATE
-    )
+class QuerySettingsForm(BaseModel):
+    k: Optional[int] = None  # missing or 0 resets TOP_K to 4 on update
+    template: Optional[str] = None  # missing or empty restores the module-level RAG_TEMPLATE
+
+
+@app.post("/query/settings/update")
+async def update_query_settings(
+    form_data: QuerySettingsForm, user=Depends(get_admin_user)
+):
+    # Replace the live query settings and echo back what is now in effect.
+    # A missing or empty template restores the module-level RAG_TEMPLATE.
+    app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
+    # A missing, zero or negative k falls back to 4 (the value TOP_K starts
+    # with); a negative value must never reach collection.query as n_results.
+    app.state.TOP_K = form_data.k if form_data.k and form_data.k > 0 else 4
+    # Return k alongside the template: the settings page overwrites its whole
+    # querySettings object with this response and binds its Top K input to .k.
-    return {"status": True, "template": app.state.RAG_TEMPLATE}
+    return {
+        "status": True,
+        "template": app.state.RAG_TEMPLATE,
+        "k": app.state.TOP_K,
+    }
 
 
 class QueryDocForm(BaseModel):
     collection_name: str
     query: str
-    k: Optional[int] = 4
+    k: Optional[int] = None  # unset (or 0) makes query_doc fall back to app.state.TOP_K
 
 
 @app.post("/query/doc")
@@ -240,7 +252,10 @@ def query_doc(
             name=form_data.collection_name,
             embedding_function=app.state.sentence_transformer_ef,
         )
-        result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
+        result = collection.query(
+            query_texts=[form_data.query],
+            n_results=form_data.k if form_data.k else app.state.TOP_K,
+        )
         return result
     except Exception as e:
         print(e)
@@ -253,7 +268,7 @@ def query_doc(
 class QueryCollectionsForm(BaseModel):
     collection_names: List[str]
     query: str
-    k: Optional[int] = 4
+    k: Optional[int] = None  # unset (or 0) makes query_collection fall back to app.state.TOP_K
 
 
 def merge_and_sort_query_results(query_results, k):
@@ -317,13 +332,16 @@ def query_collection(
             )
 
             result = collection.query(
-                query_texts=[form_data.query], n_results=form_data.k
+                query_texts=[form_data.query],
+                n_results=form_data.k if form_data.k else app.state.TOP_K,
             )
             results.append(result)
         except:
             pass
 
-    return merge_and_sort_query_results(results, form_data.k)
+    return merge_and_sort_query_results(
+        results, form_data.k if form_data.k else app.state.TOP_K
+    )
 
 
 @app.post("/web")
@@ -423,7 +441,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
         "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
     ] or file_ext in ["xls", "xlsx"]:
         loader = UnstructuredExcelLoader(file_path)
-    elif file_ext in known_source_ext or (file_content_type and file_content_type.find("text/") >= 0):
+    elif file_ext in known_source_ext or (
+        file_content_type and file_content_type.find("text/") >= 0
+    ):
         loader = TextLoader(file_path)
     else:
         loader = TextLoader(file_path)

+ 36 - 4
src/lib/apis/rag/index.ts

@@ -85,17 +85,49 @@ export const getRAGTemplate = async (token: string) => {
 	return res?.template ?? '';
 };
 
-export const updateRAGTemplate = async (token: string, template: string) => {
+export const getQuerySettings = async (token: string) => { // GET {RAG_API_BASE_URL}/query/settings; resolves to the parsed JSON body
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/template/update`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/query/settings`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json(); // non-2xx: throw the parsed error body into the catch below
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail; // only failures carrying `detail` are rethrown below; any other failure resolves to null
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+type QuerySettings = {
+	k: number | null; // "Top K" value edited on the settings page
+	template: string | null; // "RAG Template" text edited on the settings page
+};
+
+export const updateQuerySettings = async (token: string, settings: QuerySettings) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/query/settings/update`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
-			template: template
+			...settings
 		})
 	})
 		.then(async (res) => {
@@ -183,7 +215,7 @@ export const queryDoc = async (
 	token: string,
 	collection_name: string,
 	query: string,
-	k: number
+	k: number | null = null
 ) => {
 	let error = null;
 

+ 43 - 6
src/lib/components/documents/Settings/General.svelte

@@ -2,10 +2,10 @@
 	import { getDocs } from '$lib/apis/documents';
 	import {
 		getChunkParams,
-		getRAGTemplate,
+		getQuerySettings,
 		scanDocs,
 		updateChunkParams,
-		updateRAGTemplate
+		updateQuerySettings
 	} from '$lib/apis/rag';
 	import { documents } from '$lib/stores';
 	import { onMount } from 'svelte';
@@ -18,7 +18,10 @@
 	let chunkSize = 0;
 	let chunkOverlap = 0;
 
-	let template = '';
+	let querySettings = {
+		template: '',
+		k: 4
+	};
 
 	const scanHandler = async () => {
 		loading = true;
@@ -33,7 +36,7 @@
 
 	const submitHandler = async () => {
 		const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap);
-		await updateRAGTemplate(localStorage.token, template);
+		querySettings = await updateQuerySettings(localStorage.token, querySettings);
 	};
 
 	onMount(async () => {
@@ -44,7 +47,7 @@
 			chunkOverlap = res.chunk_overlap;
 		}
 
-		template = await getRAGTemplate(localStorage.token);
+		querySettings = await getQuerySettings(localStorage.token);
 	});
 </script>
 
@@ -156,10 +159,44 @@
 				</div>
 			</div>
 
+			<div class=" text-sm font-medium">Query Params</div>
+
+			<div class=" flex">
+				<div class="  flex w-full justify-between">
+					<div class="self-center text-xs font-medium flex-1">Top K</div>
+
+					<div class="self-center p-3">
+						<input
+							class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+							type="number"
+							placeholder="Enter Top K"
+							bind:value={querySettings.k}
+							autocomplete="off"
+							min="0"
+						/>
+					</div>
+				</div>
+
+				<!-- <div class="flex w-full">
+					<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
+
+					<div class="self-center p-3">
+						<input
+							class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+							type="number"
+							placeholder="Enter Chunk Overlap"
+							bind:value={chunkOverlap}
+							autocomplete="off"
+							min="0"
+						/>
+					</div>
+				</div> -->
+			</div>
+
 			<div>
 				<div class=" mb-2.5 text-sm font-medium">RAG Template</div>
 				<textarea
-					bind:value={template}
+					bind:value={querySettings.template}
 					class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
 					rows="4"
 				/>

+ 5 - 7
src/routes/(app)/+page.svelte

@@ -248,19 +248,17 @@
 			let relevantContexts = await Promise.all(
 				docs.map(async (doc) => {
 					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch(
+						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
 							(error) => {
 								console.log(error);
 								return null;
 							}
 						);
 					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
+						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
+							console.log(error);
+							return null;
+						});
 					}
 				})
 			);

+ 5 - 7
src/routes/(app)/c/[id]/+page.svelte

@@ -261,19 +261,17 @@
 			let relevantContexts = await Promise.all(
 				docs.map(async (doc) => {
 					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch(
+						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
 							(error) => {
 								console.log(error);
 								return null;
 							}
 						);
 					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
+						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
+							console.log(error);
+							return null;
+						});
 					}
 				})
 			);