Browse Source

feat: frontend integration

Timothy J. Baek 1 year ago
parent
commit
9cdb5bf9fe

+ 22 - 12
backend/apps/rag/main.py

@@ -138,20 +138,22 @@ async def get_status():
     }
     }
 
 
 
 
-@app.get("/embedding/model")
-async def get_embedding_model(user=Depends(get_admin_user)):
+@app.get("/embedding")
+async def get_embedding_config(user=Depends(get_admin_user)):
     return {
     return {
         "status": True,
         "status": True,
+        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
     }
     }
 
 
 
 
 class EmbeddingModelUpdateForm(BaseModel):
 class EmbeddingModelUpdateForm(BaseModel):
+    embedding_engine: str
     embedding_model: str
     embedding_model: str
 
 
 
 
-@app.post("/embedding/model/update")
-async def update_embedding_model(
+@app.post("/embedding/update")
+async def update_embedding_config(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
 ):
 
 
@@ -160,18 +162,26 @@ async def update_embedding_model(
     )
     )
 
 
     try:
     try:
-        sentence_transformer_ef = (
-            embedding_functions.SentenceTransformerEmbeddingFunction(
-                model_name=get_embedding_model_path(form_data.embedding_model, True),
-                device=DEVICE_TYPE,
-            )
-        )
+        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
 
 
-        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-        app.state.sentence_transformer_ef = sentence_transformer_ef
+        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+            app.state.sentence_transformer_ef = None
+        else:
+            sentence_transformer_ef = (
+                embedding_functions.SentenceTransformerEmbeddingFunction(
+                    model_name=get_embedding_model_path(
+                        form_data.embedding_model, True
+                    ),
+                    device=DEVICE_TYPE,
+                )
+            )
+            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+            app.state.sentence_transformer_ef = sentence_transformer_ef
 
 
         return {
         return {
             "status": True,
             "status": True,
+            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
             "embedding_model": app.state.RAG_EMBEDDING_MODEL,
             "embedding_model": app.state.RAG_EMBEDDING_MODEL,
         }
         }
 
 

+ 5 - 4
src/lib/apis/rag/index.ts

@@ -346,10 +346,10 @@ export const resetVectorDB = async (token: string) => {
 	return res;
 	return res;
 };
 };
 
 
-export const getEmbeddingModel = async (token: string) => {
+export const getEmbeddingConfig = async (token: string) => {
 	let error = null;
 	let error = null;
 
 
-	const res = await fetch(`${RAG_API_BASE_URL}/embedding/model`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding`, {
 		method: 'GET',
 		method: 'GET',
 		headers: {
 		headers: {
 			'Content-Type': 'application/json',
 			'Content-Type': 'application/json',
@@ -374,13 +374,14 @@ export const getEmbeddingModel = async (token: string) => {
 };
 };
 
 
 type EmbeddingModelUpdateForm = {
 type EmbeddingModelUpdateForm = {
+	embedding_engine: string;
 	embedding_model: string;
 	embedding_model: string;
 };
 };
 
 
-export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => {
+export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
 	let error = null;
 	let error = null;
 
 
-	const res = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding/update`, {
 		method: 'POST',
 		method: 'POST',
 		headers: {
 		headers: {
 			'Content-Type': 'application/json',
 			'Content-Type': 'application/json',

+ 107 - 17
src/lib/components/documents/Settings/General.svelte

@@ -7,11 +7,11 @@
 		scanDocs,
 		scanDocs,
 		updateQuerySettings,
 		updateQuerySettings,
 		resetVectorDB,
 		resetVectorDB,
-		getEmbeddingModel,
-		updateEmbeddingModel
+		getEmbeddingConfig,
+		updateEmbeddingConfig
 	} from '$lib/apis/rag';
 	} from '$lib/apis/rag';
 
 
-	import { documents } from '$lib/stores';
+	import { documents, models } from '$lib/stores';
 	import { onMount, getContext } from 'svelte';
 	import { onMount, getContext } from 'svelte';
 	import { toast } from 'svelte-sonner';
 	import { toast } from 'svelte-sonner';
 
 
@@ -27,6 +27,8 @@
 	let showResetConfirm = false;
 	let showResetConfirm = false;
 
 
 	let embeddingEngine = '';
 	let embeddingEngine = '';
+	let embeddingModel = '';
+
 	let chunkSize = 0;
 	let chunkSize = 0;
 	let chunkOverlap = 0;
 	let chunkOverlap = 0;
 	let pdfExtractImages = true;
 	let pdfExtractImages = true;
@@ -36,8 +38,6 @@
 		k: 4
 		k: 4
 	};
 	};
 
 
-	let embeddingModel = '';
-
 	const scanHandler = async () => {
 	const scanHandler = async () => {
 		scanDirLoading = true;
 		scanDirLoading = true;
 		const res = await scanDocs(localStorage.token);
 		const res = await scanDocs(localStorage.token);
@@ -50,7 +50,16 @@
 	};
 	};
 
 
 	const embeddingModelUpdateHandler = async () => {
 	const embeddingModelUpdateHandler = async () => {
-		if (embeddingModel.split('/').length - 1 > 1) {
+		if (embeddingModel === '') {
+			toast.error(
+				$i18n.t(
+					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
+				)
+			);
+			return;
+		}
+
+		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
 			toast.error(
 			toast.error(
 				$i18n.t(
 				$i18n.t(
 					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
 					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
@@ -62,11 +71,17 @@
 		console.log('Update embedding model attempt:', embeddingModel);
 		console.log('Update embedding model attempt:', embeddingModel);
 
 
 		updateEmbeddingModelLoading = true;
 		updateEmbeddingModelLoading = true;
-		const res = await updateEmbeddingModel(localStorage.token, {
+		const res = await updateEmbeddingConfig(localStorage.token, {
+			embedding_engine: embeddingEngine,
 			embedding_model: embeddingModel
 			embedding_model: embeddingModel
 		}).catch(async (error) => {
 		}).catch(async (error) => {
 			toast.error(error);
 			toast.error(error);
-			embeddingModel = (await getEmbeddingModel(localStorage.token)).embedding_model;
+
+			const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+			if (embeddingConfig) {
+				embeddingEngine = embeddingConfig.embedding_engine;
+				embeddingModel = embeddingConfig.embedding_model;
+			}
 			return null;
 			return null;
 		});
 		});
 		updateEmbeddingModelLoading = false;
 		updateEmbeddingModelLoading = false;
@@ -102,7 +117,12 @@
 			chunkOverlap = res.chunk.chunk_overlap;
 			chunkOverlap = res.chunk.chunk_overlap;
 		}
 		}
 
 
-		embeddingModel = (await getEmbeddingModel(localStorage.token)).embedding_model;
+		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+
+		if (embeddingConfig) {
+			embeddingEngine = embeddingConfig.embedding_engine;
+			embeddingModel = embeddingConfig.embedding_model;
+		}
 
 
 		querySettings = await getQuerySettings(localStorage.token);
 		querySettings = await getQuerySettings(localStorage.token);
 	});
 	});
@@ -126,6 +146,9 @@
 						class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
 						class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
 						bind:value={embeddingEngine}
 						bind:value={embeddingEngine}
 						placeholder="Select an embedding engine"
 						placeholder="Select an embedding engine"
+						on:change={() => {
+							embeddingModel = '';
+						}}
 					>
 					>
 						<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
 						<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
 						<option value="ollama">{$i18n.t('Ollama')}</option>
 						<option value="ollama">{$i18n.t('Ollama')}</option>
@@ -136,10 +159,77 @@
 
 
 		<div class="space-y-2">
 		<div class="space-y-2">
 			<div>
 			<div>
+				<div class=" mb-2 text-sm font-medium">{$i18n.t('Update Embedding Model')}</div>
+
 				{#if embeddingEngine === 'ollama'}
 				{#if embeddingEngine === 'ollama'}
-					<div>da</div>
+					<div class="flex w-full">
+						<div class="flex-1 mr-2">
+							<select
+								class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+								bind:value={embeddingModel}
+								placeholder={$i18n.t('Select a model')}
+								required
+							>
+								{#if !embeddingModel}
+									<option value="" disabled selected>{$i18n.t('Select a model')}</option>
+								{/if}
+								{#each $models.filter((m) => m.id && !m.external) as model}
+									<option value={model.name} class="bg-gray-100 dark:bg-gray-700"
+										>{model.name + ' (' + (model.size / 1024 ** 3).toFixed(1) + ' GB)'}</option
+									>
+								{/each}
+							</select>
+						</div>
+						<button
+							class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+							on:click={() => {
+								embeddingModelUpdateHandler();
+							}}
+							disabled={updateEmbeddingModelLoading}
+						>
+							{#if updateEmbeddingModelLoading}
+								<div class="self-center">
+									<svg
+										class=" w-4 h-4"
+										viewBox="0 0 24 24"
+										fill="currentColor"
+										xmlns="http://www.w3.org/2000/svg"
+										><style>
+											.spinner_ajPY {
+												transform-origin: center;
+												animation: spinner_AtaB 0.75s infinite linear;
+											}
+											@keyframes spinner_AtaB {
+												100% {
+													transform: rotate(360deg);
+												}
+											}
+										</style><path
+											d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
+											opacity=".25"
+										/><path
+											d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
+											class="spinner_ajPY"
+										/></svg
+									>
+								</div>
+							{:else}
+								<svg
+									xmlns="http://www.w3.org/2000/svg"
+									viewBox="0 0 16 16"
+									fill="currentColor"
+									class="w-4 h-4"
+								>
+									<path
+										fill-rule="evenodd"
+										d="M12.416 3.376a.75.75 0 0 1 .208 1.04l-5 7.5a.75.75 0 0 1-1.154.114l-3-3a.75.75 0 0 1 1.06-1.06l2.353 2.353 4.493-6.74a.75.75 0 0 1 1.04-.207Z"
+										clip-rule="evenodd"
+									/>
+								</svg>
+							{/if}
+						</button>
+					</div>
 				{:else}
 				{:else}
-					<div class=" mb-2 text-sm font-medium">{$i18n.t('Update Embedding Model')}</div>
 					<div class="flex w-full">
 					<div class="flex w-full">
 						<div class="flex-1 mr-2">
 						<div class="flex-1 mr-2">
 							<input
 							<input
@@ -200,14 +290,14 @@
 							{/if}
 							{/if}
 						</button>
 						</button>
 					</div>
 					</div>
-
-					<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
-						{$i18n.t(
-							'Warning: If you update or change your embedding model, you will need to re-import all documents.'
-						)}
-					</div>
 				{/if}
 				{/if}
 
 
+				<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
+					{$i18n.t(
+						'Warning: If you update or change your embedding model, you will need to re-import all documents.'
+					)}
+				</div>
+
 				<hr class=" dark:border-gray-700 my-3" />
 				<hr class=" dark:border-gray-700 my-3" />
 
 
 				<div class="  flex w-full justify-between">
 				<div class="  flex w-full justify-between">