소스 검색

Improve embedding model update & resolve network dependency

* Add config variable RAG_EMBEDDING_MODEL_AUTO_UPDATE to control update behavior
* Add RAG utils embedding_model_get_path() function to output the filesystem path in addition to update of the model using huggingface_hub
* Update and utilize existing RAG functions in main: get_embedding_model() & update_embedding_model()
* Add GUI setting to execute manual update process
Self Denial 1 년 전
부모
커밋
3b66aa55c0
5개의 변경된 파일218개의 추가작업 그리고 19개의 파일을 삭제
  1. 33 18
      backend/apps/rag/main.py
  2. 35 0
      backend/apps/rag/utils.py
  3. 3 0
      backend/config.py
  4. 61 0
      src/lib/apis/rag/index.ts
  5. 86 1
      src/lib/components/documents/Settings/General.svelte

+ 33 - 18
backend/apps/rag/main.py

@@ -13,7 +13,6 @@ import os, shutil, logging, re
 from pathlib import Path
 from typing import List
 
-from sentence_transformers import SentenceTransformer
 from chromadb.utils import embedding_functions
 
 from langchain_community.document_loaders import (
@@ -45,7 +44,7 @@ from apps.web.models.documents import (
     DocumentResponse,
 )
 
-from apps.rag.utils import query_doc, query_collection
+from apps.rag.utils import query_doc, query_collection, embedding_model_get_path
 
 from utils.misc import (
     calculate_sha256,
@@ -60,6 +59,7 @@ from config import (
     DOCS_DIR,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
@@ -71,15 +71,6 @@ from constants import ERROR_MESSAGES
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
-#
-# if RAG_EMBEDDING_MODEL:
-#    sentence_transformer_ef = SentenceTransformer(
-#        model_name_or_path=RAG_EMBEDDING_MODEL,
-#        cache_folder=RAG_EMBEDDING_MODEL_DIR,
-#        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
-#    )
-
-
 app = FastAPI()
 
 app.state.PDF_EXTRACT_IMAGES = False
@@ -87,11 +78,12 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE)
 app.state.TOP_K = 4
 
 app.state.sentence_transformer_ef = (
     embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name=app.state.RAG_EMBEDDING_MODEL,
+        model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
     )
 )
@@ -132,6 +124,7 @@ async def get_embedding_model(user=Depends(get_admin_user)):
     return {
         "status": True,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
     }
 
 
@@ -143,17 +136,39 @@ class EmbeddingModelUpdateForm(BaseModel):
 async def update_embedding_model(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
+    status = True
+    old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
     app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-    app.state.sentence_transformer_ef = (
-        embedding_functions.SentenceTransformerEmbeddingFunction(
-            model_name=app.state.RAG_EMBEDDING_MODEL,
-            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+
+    log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
+    log.info(f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}")
+
+    try:
+        app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, True)
+        app.state.sentence_transformer_ef = (
+            embedding_functions.SentenceTransformerEmbeddingFunction(
+                model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
+                device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+            )
+        )
+    except Exception as e: 
+        log.exception(f"Problem updating embedding model: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=e,
         )
-    )
+
+    if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path:
+      status = False
+
+    log.debug(f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}")
+    log.debug(f"old_model_path: {old_model_path}")
+    log.debug(f"status: {status}")
 
     return {
-        "status": True,
+        "status": status,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
     }
 
 

+ 35 - 0
backend/apps/rag/utils.py

@@ -1,6 +1,8 @@
+import os
 import re
 import logging
 from typing import List
+from huggingface_hub import snapshot_download
 
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
@@ -188,3 +190,36 @@ def rag_messages(docs, messages, template, k, embedding_function):
     messages[last_user_message_idx] = new_user_message
 
     return messages
+
+def embedding_model_get_path(embedding_model: str, update_embedding_model: bool = False):
+    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
+    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
+    local_files_only = not update_embedding_model
+    snapshot_kwargs = {
+        "cache_dir": cache_dir,
+        "local_files_only": local_files_only,
+    }
+
+    log.debug(f"SENTENCE_TRANSFORMERS_HOME cache_dir: {cache_dir}")
+    log.debug(f"embedding_model: {embedding_model}")
+    log.debug(f"update_embedding_model: {update_embedding_model}")
+    log.debug(f"local_files_only: {local_files_only}")
+
+    # Inspiration from upstream sentence_transformers
+    if (os.path.exists(embedding_model) or ("\\" in embedding_model or embedding_model.count("/") > 1) and local_files_only):
+        # If fully qualified path exists, return input, else set repo_id
+        return embedding_model
+    elif "/" not in embedding_model:
+        # Set valid repo_id for model short-name
+        embedding_model = "sentence-transformers" + "/" + embedding_model
+
+    snapshot_kwargs["repo_id"] = embedding_model
+
+    # Attempt to query the huggingface_hub library to determine the local path and/or to update
+    try:
+        embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
+        log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
+        return embedding_model_repo_path
+    except Exception as e:
+        log.exception(f"Cannot determine embedding model snapshot path: {e}")
+        return embedding_model

+ 3 - 0
backend/config.py

@@ -395,6 +395,9 @@ RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
 RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
     "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
 )
+RAG_EMBEDDING_MODEL_AUTO_UPDATE = False
+if os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true":
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE = True
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
     settings=Settings(allow_reset=True, anonymized_telemetry=False),

+ 61 - 0
src/lib/apis/rag/index.ts

@@ -345,3 +345,64 @@ export const resetVectorDB = async (token: string) => {
 
 	return res;
 };
+
+export const getEmbeddingModel = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding/model`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+type EmbeddingModelUpdateForm = {
+	embedding_model: string;
+};
+
+export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...payload
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 86 - 1
src/lib/components/documents/Settings/General.svelte

@@ -6,7 +6,9 @@
 		getQuerySettings,
 		scanDocs,
 		updateQuerySettings,
-		resetVectorDB
+		resetVectorDB,
+		getEmbeddingModel,
+		updateEmbeddingModel
 	} from '$lib/apis/rag';
 
 	import { documents } from '$lib/stores';
@@ -18,6 +20,7 @@
 	export let saveHandler: Function;
 
 	let loading = false;
+	let loading1 = false;
 
 	let showResetConfirm = false;
 
@@ -30,6 +33,10 @@
 		k: 4
 	};
 
+	let embeddingModel = {
+		embedding_model: '',
+	};
+
 	const scanHandler = async () => {
 		loading = true;
 		const res = await scanDocs(localStorage.token);
@@ -41,6 +48,21 @@
 		}
 	};
 
+	const embeddingModelUpdateHandler = async () => {
+		loading1 = true;
+		const res = await updateEmbeddingModel(localStorage.token, embeddingModel);
+		loading1 = false;
+
+		if (res) {
+			console.log('embeddingModelUpdateHandler:', res);
+			if (res.status == true) {
+				toast.success($i18n.t('Model {{embedding_model}} update complete!', res));
+			} else {
+				toast.error($i18n.t('Model {{embedding_model}} update failed or not required!', res));
+			}
+		}
+	};
+
 	const submitHandler = async () => {
 		const res = await updateRAGConfig(localStorage.token, {
 			pdf_extract_images: pdfExtractImages,
@@ -62,6 +84,8 @@
 			chunkOverlap = res.chunk.chunk_overlap;
 		}
 
+		embeddingModel = await getEmbeddingModel(localStorage.token);
+
 		querySettings = await getQuerySettings(localStorage.token);
 	});
 </script>
@@ -137,6 +161,67 @@
 					{/if}
 				</button>
 			</div>
+
+			<div class="  flex w-full justify-between">
+				<div class=" self-center text-xs font-medium">
+					{$i18n.t('Update embedding model {{embedding_model}}', embeddingModel)}
+				</div>
+
+				<button
+					class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded flex flex-row space-x-1 items-center {loading1
+						? ' cursor-not-allowed'
+						: ''}"
+					on:click={() => {
+						embeddingModelUpdateHandler(embeddingModel);
+						console.log('Update embedding model:', embeddingModel.embedding_model);
+					}}
+					type="button"
+					disabled={loading1}
+				>
+					<div class="self-center font-medium">{$i18n.t('Update')}</div>
+
+					<!-- <svg
+						xmlns="http://www.w3.org/2000/svg"
+						viewBox="0 0 16 16"
+						fill="currentColor"
+						class="w-3 h-3"
+					>
+						<path
+							fill-rule="evenodd"
+							d="M13.836 2.477a.75.75 0 0 1 .75.75v3.182a.75.75 0 0 1-.75.75h-3.182a.75.75 0 0 1 0-1.5h1.37l-.84-.841a4.5 4.5 0 0 0-7.08.932.75.75 0 0 1-1.3-.75 6 6 0 0 1 9.44-1.242l.842.84V3.227a.75.75 0 0 1 .75-.75Zm-.911 7.5A.75.75 0 0 1 13.199 11a6 6 0 0 1-9.44 1.241l-.84-.84v1.371a.75.75 0 0 1-1.5 0V9.591a.75.75 0 0 1 .75-.75H5.35a.75.75 0 0 1 0 1.5H3.98l.841.841a4.5 4.5 0 0 0 7.08-.932.75.75 0 0 1 1.025-.273Z"
+							clip-rule="evenodd"
+						/>
+					</svg> -->
+
+					{#if loading1}
+						<div class="ml-3 self-center">
+							<svg
+								class=" w-3 h-3"
+								viewBox="0 0 24 24"
+								fill="currentColor"
+								xmlns="http://www.w3.org/2000/svg"
+								><style>
+									.spinner_ajPY {
+										transform-origin: center;
+										animation: spinner_AtaB 0.75s infinite linear;
+									}
+									@keyframes spinner_AtaB {
+										100% {
+											transform: rotate(360deg);
+										}
+									}
+								</style><path
+									d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
+									opacity=".25"
+								/><path
+									d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
+									class="spinner_ajPY"
+								/></svg
+							>
+						</div>
+					{/if}
+				</button>
+			</div>
 		</div>
 
 		<hr class=" dark:border-gray-700" />