Browse Source

Merge pull request #779 from open-webui/editable-chunk-params

feat: editable chunk params
Timothy Jaeryang Baek 1 year ago
parent
commit
b993b66cfb

+ 61 - 2
backend/apps/rag/main.py

@@ -62,6 +62,7 @@ from config import (
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
+    RAG_TEMPLATE,
 )
 from constants import ERROR_MESSAGES
 
@@ -71,6 +72,11 @@ from constants import ERROR_MESSAGES
 
 app = FastAPI()
 
+app.state.CHUNK_SIZE = CHUNK_SIZE
+app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
+app.state.RAG_TEMPLATE = RAG_TEMPLATE
+
+
 origins = ["*"]
 
 app.add_middleware(
@@ -92,7 +98,7 @@ class StoreWebForm(CollectionNameForm):
 
 def store_data_in_vector_db(data, collection_name) -> bool:
     text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
+        chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
     )
     docs = text_splitter.split_documents(data)
 
@@ -116,7 +122,60 @@ def store_data_in_vector_db(data, collection_name) -> bool:
 
 @app.get("/")
 async def get_status():
-    return {"status": True}
+    return {
+        "status": True,
+        "chunk_size": app.state.CHUNK_SIZE,
+        "chunk_overlap": app.state.CHUNK_OVERLAP,
+    }
+
+
+@app.get("/chunk")
+async def get_chunk_params(user=Depends(get_admin_user)):
+    return {
+        "status": True,
+        "chunk_size": app.state.CHUNK_SIZE,
+        "chunk_overlap": app.state.CHUNK_OVERLAP,
+    }
+
+
+class ChunkParamUpdateForm(BaseModel):
+    chunk_size: int
+    chunk_overlap: int
+
+
+@app.post("/chunk/update")
+async def update_chunk_params(
+    form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
+):
+    app.state.CHUNK_SIZE = form_data.chunk_size
+    app.state.CHUNK_OVERLAP = form_data.chunk_overlap
+
+    return {
+        "status": True,
+        "chunk_size": app.state.CHUNK_SIZE,
+        "chunk_overlap": app.state.CHUNK_OVERLAP,
+    }
+
+
+@app.get("/template")
+async def get_rag_template(user=Depends(get_current_user)):
+    return {
+        "status": True,
+        "template": app.state.RAG_TEMPLATE,
+    }
+
+
+class RAGTemplateForm(BaseModel):
+    template: str
+
+
+@app.post("/template/update")
+async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
+    # TODO: check template requirements
+    app.state.RAG_TEMPLATE = (
+        form_data.template if form_data.template != "" else RAG_TEMPLATE
+    )
+    return {"status": True, "template": app.state.RAG_TEMPLATE}
 
 
 class QueryDocForm(BaseModel):

+ 15 - 0
backend/config.py

@@ -144,6 +144,21 @@ CHROMA_CLIENT = chromadb.PersistentClient(
 CHUNK_SIZE = 1500
 CHUNK_OVERLAP = 100
 
+
+RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
+<context>
+    [context]
+</context>
+
+When answer to user:
+- If you don't know, just say that you don't know.
+- If you don't know when you are not sure, ask for clarification.
+Avoid mentioning that you obtained the information from the context.
+And answer according to the language of the user's question.
+        
+Given the context information, answer the query.
+Query: [query]"""
+
 ####################################
 # Transcribe
 ####################################

+ 115 - 0
src/lib/apis/rag/index.ts

@@ -1,5 +1,120 @@
 import { RAG_API_BASE_URL } from '$lib/constants';
 
+export const getChunkParams = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/chunk`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateChunkParams = async (token: string, size: number, overlap: number) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/chunk/update`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			chunk_size: size,
+			chunk_overlap: overlap
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getRAGTemplate = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/template`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res?.template ?? '';
+};
+
+export const updateRAGTemplate = async (token: string, template: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/template/update`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			template: template
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => {
 	const data = new FormData();
 	data.append('file', file);

+ 77 - 5
src/lib/components/documents/Settings/General.svelte

@@ -1,6 +1,12 @@
 <script lang="ts">
 	import { getDocs } from '$lib/apis/documents';
-	import { scanDocs } from '$lib/apis/rag';
+	import {
+		getChunkParams,
+		getRAGTemplate,
+		scanDocs,
+		updateChunkParams,
+		updateRAGTemplate
+	} from '$lib/apis/rag';
 	import { documents } from '$lib/stores';
 	import { onMount } from 'svelte';
 	import toast from 'svelte-french-toast';
@@ -9,6 +15,11 @@
 
 	let loading = false;
 
+	let chunkSize = 0;
+	let chunkOverlap = 0;
+
+	let template = '';
+
 	const scanHandler = async () => {
 		loading = true;
 		const res = await scanDocs(localStorage.token);
@@ -20,13 +31,27 @@
 		}
 	};
 
-	onMount(async () => {});
+	const submitHandler = async () => {
+		const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap);
+		await updateRAGTemplate(localStorage.token, template);
+	};
+
+	onMount(async () => {
+		const res = await getChunkParams(localStorage.token);
+
+		if (res) {
+			chunkSize = res.chunk_size;
+			chunkOverlap = res.chunk_overlap;
+		}
+
+		template = await getRAGTemplate(localStorage.token);
+	});
 </script>
 
 <form
 	class="flex flex-col h-full justify-between space-y-3 text-sm"
 	on:submit|preventDefault={() => {
-		// console.log('submit');
+		submitHandler();
 		saveHandler();
 	}}
 >
@@ -93,14 +118,61 @@
 				</button>
 			</div>
 		</div>
+
+		<hr class=" dark:border-gray-700" />
+
+		<div class=" ">
+			<div class=" text-sm font-medium">Chunk Params</div>
+
+			<div class=" flex">
+				<div class="  flex w-full justify-between">
+					<div class="self-center text-xs font-medium min-w-fit">Chunk Size</div>
+
+					<div class="self-center p-3">
+						<input
+							class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+							type="number"
+							placeholder="Enter Chunk Size"
+							bind:value={chunkSize}
+							autocomplete="off"
+							min="0"
+						/>
+					</div>
+				</div>
+
+				<div class="flex w-full">
+					<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
+
+					<div class="self-center p-3">
+						<input
+							class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+							type="number"
+							placeholder="Enter Chunk Overlap"
+							bind:value={chunkOverlap}
+							autocomplete="off"
+							min="0"
+						/>
+					</div>
+				</div>
+			</div>
+
+			<div>
+				<div class=" mb-2.5 text-sm font-medium">RAG Template</div>
+				<textarea
+					bind:value={template}
+					class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
+					rows="4"
+				/>
+			</div>
+		</div>
 	</div>
 
-	<!-- <div class="flex justify-end pt-3 text-sm font-medium">
+	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
 			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
 			type="submit"
 		>
 			Save
 		</button>
-	</div> -->
+	</div>
 </form>

+ 18 - 14
src/lib/utils/rag/index.ts

@@ -1,17 +1,21 @@
-export const RAGTemplate = (context: string, query: string) => {
-	let template = `Use the following context as your learned knowledge, inside <context></context> XML tags.
-	<context>
-	  [context]
-	</context>
-	
-	When answer to user:
-	- If you don't know, just say that you don't know.
-	- If you don't know when you are not sure, ask for clarification.
-	Avoid mentioning that you obtained the information from the context.
-	And answer according to the language of the user's question.
-			
-	Given the context information, answer the query.
-	Query: [query]`;
+import { getRAGTemplate } from '$lib/apis/rag';
+
+export const RAGTemplate = async (token: string, context: string, query: string) => {
+	let template = await getRAGTemplate(token).catch(() => {
+		return `Use the following context as your learned knowledge, inside <context></context> XML tags.
+		<context>
+		  [context]
+		</context>
+		
+		When answer to user:
+		- If you don't know, just say that you don't know.
+		- If you don't know when you are not sure, ask for clarification.
+		Avoid mentioning that you obtained the information from the context.
+		And answer according to the language of the user's question.
+				
+		Given the context information, answer the query.
+		Query: [query]`;
+	});
 
 	template = template.replace(/\[context\]/g, context);
 	template = template.replace(/\[query\]/g, query);

+ 5 - 1
src/routes/(app)/+page.svelte

@@ -266,7 +266,11 @@
 
 			console.log(contextString);
 
-			history.messages[parentId].raContent = RAGTemplate(contextString, query);
+			history.messages[parentId].raContent = await RAGTemplate(
+				localStorage.token,
+				contextString,
+				query
+			);
 			history.messages[parentId].contexts = relevantContexts;
 			await tick();
 			processing = '';

+ 5 - 1
src/routes/(app)/c/[id]/+page.svelte

@@ -280,7 +280,11 @@
 
 			console.log(contextString);
 
-			history.messages[parentId].raContent = RAGTemplate(contextString, query);
+			history.messages[parentId].raContent = await RAGTemplate(
+				localStorage.token,
+				contextString,
+				query
+			);
 			history.messages[parentId].contexts = relevantContexts;
 			await tick();
 			processing = '';