Timothy J. Baek 1 yıl önce
ebeveyn
işleme
abfcceecef

+ 17 - 20
backend/apps/rag/main.py

@@ -142,43 +142,40 @@ class EmbeddingModelUpdateForm(BaseModel):
 async def update_embedding_model(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
-    status = True
-    old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
-    app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
 
     log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
     log.info(
         f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
     )
 
+    embedding_model_path = None
+    sentence_transformer_ef = None
     try:
-        app.state.RAG_EMBEDDING_MODEL_PATH = get_embedding_model_path(
-            app.state.RAG_EMBEDDING_MODEL, True
-        )
-        app.state.sentence_transformer_ef = (
-            embedding_functions.SentenceTransformerEmbeddingFunction(
-                model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
-                device=DEVICE_TYPE,
+        embedding_model_path = get_embedding_model_path(form_data.embedding_model, True)
+        if app.state.RAG_EMBEDDING_MODEL_PATH != embedding_model_path:
+            sentence_transformer_ef = (
+                embedding_functions.SentenceTransformerEmbeddingFunction(
+                    model_name=embedding_model_path,
+                    device=DEVICE_TYPE,
+                )
             )
-        )
     except Exception as e:
         log.exception(f"Problem updating embedding model: {e}")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=e,
+            detail=ERROR_MESSAGES.DEFAULT(e),
         )
 
-    if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path:
-        status = False
+    if sentence_transformer_ef:
+        app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_path
+        app.state.sentence_transformer_ef = sentence_transformer_ef
 
-    log.debug(
-        f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}"
-    )
-    log.debug(f"old_model_path: {old_model_path}")
-    log.debug(f"status: {status}")
+        log.debug(
+            f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}"
+        )
 
     return {
-        "status": status,
+        "status": sentence_transformer_ef != None,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
         "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
     }

+ 11 - 3
src/lib/components/documents/Settings/General.svelte

@@ -35,6 +35,9 @@
 		k: 4
 	};
 
+	let embeddingModelConfig = {
+		embedding_model: ''
+	};
 	let embeddingModel = '';
 
 	const scanHandler = async () => {
@@ -61,7 +64,13 @@
 		console.log('Update embedding model attempt:', embeddingModel);
 
 		updateEmbeddingModelLoading = true;
-		const res = await updateEmbeddingModel(localStorage.token, { embedding_model: embeddingModel });
+		const res = await updateEmbeddingModel(localStorage.token, {
+			embedding_model: embeddingModel
+		}).catch((error) => {
+			toast.error(error);
+			embeddingModel = embeddingModelConfig.embedding_model;
+			return null;
+		});
 		updateEmbeddingModelLoading = false;
 
 		if (res) {
@@ -99,8 +108,7 @@
 			chunkOverlap = res.chunk.chunk_overlap;
 		}
 
-		const embeddingModelConfig = await getEmbeddingModel(localStorage.token);
-
+		embeddingModelConfig = await getEmbeddingModel(localStorage.token);
 		embeddingModel = embeddingModelConfig.embedding_model;
 
 		querySettings = await getQuerySettings(localStorage.token);