|
@@ -13,7 +13,6 @@ import os, shutil, logging, re
|
|
|
from pathlib import Path
|
|
|
from typing import List
|
|
|
|
|
|
-from sentence_transformers import SentenceTransformer
|
|
|
from chromadb.utils import embedding_functions
|
|
|
|
|
|
from langchain_community.document_loaders import (
|
|
@@ -45,7 +44,7 @@ from apps.web.models.documents import (
|
|
|
DocumentResponse,
|
|
|
)
|
|
|
|
|
|
-from apps.rag.utils import query_doc, query_collection
|
|
|
+from apps.rag.utils import query_doc, query_collection, embedding_model_get_path
|
|
|
|
|
|
from utils.misc import (
|
|
|
calculate_sha256,
|
|
@@ -60,6 +59,7 @@ from config import (
|
|
|
DOCS_DIR,
|
|
|
RAG_EMBEDDING_MODEL,
|
|
|
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
|
|
+ RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
|
|
CHROMA_CLIENT,
|
|
|
CHUNK_SIZE,
|
|
|
CHUNK_OVERLAP,
|
|
@@ -71,15 +71,6 @@ from constants import ERROR_MESSAGES
|
|
|
log = logging.getLogger(__name__)
|
|
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
|
|
|
-#
|
|
|
-# if RAG_EMBEDDING_MODEL:
|
|
|
-# sentence_transformer_ef = SentenceTransformer(
|
|
|
-# model_name_or_path=RAG_EMBEDDING_MODEL,
|
|
|
-# cache_folder=RAG_EMBEDDING_MODEL_DIR,
|
|
|
-# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
|
|
-# )
|
|
|
-
|
|
|
-
|
|
|
app = FastAPI()
|
|
|
|
|
|
app.state.PDF_EXTRACT_IMAGES = False
|
|
@@ -87,11 +78,12 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
|
|
|
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|
|
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
|
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
|
+app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE)
|
|
|
app.state.TOP_K = 4
|
|
|
|
|
|
app.state.sentence_transformer_ef = (
|
|
|
embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
|
- model_name=app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
|
|
|
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
|
|
)
|
|
|
)
|
|
@@ -132,6 +124,7 @@ async def get_embedding_model(user=Depends(get_admin_user)):
|
|
|
return {
|
|
|
"status": True,
|
|
|
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
|
|
|
}
|
|
|
|
|
|
|
|
@@ -143,17 +136,39 @@ class EmbeddingModelUpdateForm(BaseModel):
|
|
|
async def update_embedding_model(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding model used by the RAG pipeline.

    Resolves the requested model name to a model path, builds a new
    SentenceTransformer embedding function for it, and only then commits
    the new name, path and embedding function to ``app.state``. If any
    step fails, ``app.state`` is left untouched.

    Args:
        form_data: Request body; ``form_data.embedding_model`` is the
            requested model name.
        user: Admin user injected by the ``get_admin_user`` dependency.

    Returns:
        A dict with:
        - ``status``: ``False`` when the resolved model path is identical
          to the path that was active before the call, ``True`` otherwise.
        - ``embedding_model``: the model name now stored in ``app.state``.
        - ``embedding_model_path``: the resolved model path now stored in
          ``app.state``.

    Raises:
        HTTPException: 500 with the error text as detail when the model
            path cannot be resolved or the embedding function cannot be
            created.
    """
    # Capture the current values before anything is changed so that the
    # log line and the "did the path change" check compare old vs. new.
    previous_model = app.state.RAG_EMBEDDING_MODEL
    old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
    requested_model = form_data.embedding_model

    log.debug(f"form_data.embedding_model: {requested_model}")
    log.info(f"Updating embedding model: {previous_model} to {requested_model}")

    try:
        # NOTE(review): the second argument is RAG_EMBEDDING_MODEL_AUTO_UPDATE
        # at start-up; passing True here presumably forces a refresh of the
        # model files -- confirm against apps.rag.utils.
        new_model_path = embedding_model_get_path(requested_model, True)
        new_embedding_function = (
            embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=new_model_path,
                device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
            )
        )
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        # `status` must remain the module-level HTTP status namespace here,
        # which is why the result flag below is NOT named `status`.
        # The detail is stringified because an exception object cannot be
        # serialized into the JSON error response.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        ) from e

    # Commit only after every step succeeded, so a failed update never
    # leaves the model name, path and embedding function out of sync.
    app.state.RAG_EMBEDDING_MODEL = requested_model
    app.state.RAG_EMBEDDING_MODEL_PATH = new_model_path
    app.state.sentence_transformer_ef = new_embedding_function

    # An unchanged path is reported to the caller as status False.
    model_path_changed = new_model_path != old_model_path

    log.debug(
        f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}"
    )
    log.debug(f"old_model_path: {old_model_path}")
    log.debug(f"status: {model_path_changed}")

    return {
        "status": model_path_changed,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
        "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
    }
|
|
|
|
|
|
|