|
@@ -39,8 +39,6 @@ import json
|
|
|
|
|
|
import sentence_transformers
|
|
import sentence_transformers
|
|
|
|
|
|
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
|
|
|
|
-
|
|
|
|
from apps.web.models.documents import (
|
|
from apps.web.models.documents import (
|
|
Documents,
|
|
Documents,
|
|
DocumentForm,
|
|
DocumentForm,
|
|
@@ -48,6 +46,7 @@ from apps.web.models.documents import (
|
|
)
|
|
)
|
|
|
|
|
|
from apps.rag.utils import (
|
|
from apps.rag.utils import (
|
|
|
|
+ get_model_path,
|
|
query_embeddings_doc,
|
|
query_embeddings_doc,
|
|
query_embeddings_function,
|
|
query_embeddings_function,
|
|
query_embeddings_collection,
|
|
query_embeddings_collection,
|
|
@@ -60,6 +59,7 @@ from utils.misc import (
|
|
extract_folders_after_data_docs,
|
|
extract_folders_after_data_docs,
|
|
)
|
|
)
|
|
from utils.utils import get_current_user, get_admin_user
|
|
from utils.utils import get_current_user, get_admin_user
|
|
|
|
+
|
|
from config import (
|
|
from config import (
|
|
SRC_LOG_LEVELS,
|
|
SRC_LOG_LEVELS,
|
|
UPLOAD_DIR,
|
|
UPLOAD_DIR,
|
|
@@ -68,8 +68,10 @@ from config import (
|
|
RAG_RELEVANCE_THRESHOLD,
|
|
RAG_RELEVANCE_THRESHOLD,
|
|
RAG_EMBEDDING_ENGINE,
|
|
RAG_EMBEDDING_ENGINE,
|
|
RAG_EMBEDDING_MODEL,
|
|
RAG_EMBEDDING_MODEL,
|
|
|
|
+ RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
|
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
RAG_RERANKING_MODEL,
|
|
RAG_RERANKING_MODEL,
|
|
|
|
+ RAG_RERANKING_MODEL_AUTO_UPDATE,
|
|
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
RAG_OPENAI_API_BASE_URL,
|
|
RAG_OPENAI_API_BASE_URL,
|
|
RAG_OPENAI_API_KEY,
|
|
RAG_OPENAI_API_KEY,
|
|
@@ -87,13 +89,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
|
|
|
app = FastAPI()
|
|
app = FastAPI()
|
|
|
|
|
|
-
|
|
|
|
app.state.TOP_K = RAG_TOP_K
|
|
app.state.TOP_K = RAG_TOP_K
|
|
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
|
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
|
app.state.CHUNK_SIZE = CHUNK_SIZE
|
|
app.state.CHUNK_SIZE = CHUNK_SIZE
|
|
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|
|
|
|
|
-
|
|
|
|
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
|
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
|
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
|
@@ -104,27 +104,48 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
|
|
|
|
|
app.state.PDF_EXTRACT_IMAGES = False
|
|
app.state.PDF_EXTRACT_IMAGES = False
|
|
|
|
|
|
-if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
- app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
|
|
|
- app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- device=DEVICE_TYPE,
|
|
|
|
- trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
|
|
- )
|
|
|
|
-else:
|
|
|
|
- app.state.sentence_transformer_ef = None
|
|
|
|
-
|
|
|
|
-if not app.state.RAG_RERANKING_MODEL == "":
|
|
|
|
- app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
|
|
|
- app.state.RAG_RERANKING_MODEL,
|
|
|
|
- device=DEVICE_TYPE,
|
|
|
|
- trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
|
|
- )
|
|
|
|
-else:
|
|
|
|
- app.state.sentence_transformer_rf = None
|
|
|
|
|
|
|
|
|
|
+def update_embedding_model(
|
|
|
|
+ embedding_model: str,
|
|
|
|
+ update_model: bool = False,
|
|
|
|
+):
|
|
|
|
+ if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
+ app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
|
|
|
+ get_model_path(embedding_model, update_model),
|
|
|
|
+ device=DEVICE_TYPE,
|
|
|
|
+ trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ app.state.sentence_transformer_ef = None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def update_reranking_model(
|
|
|
|
+ reranking_model: str,
|
|
|
|
+ update_model: bool = False,
|
|
|
|
+):
|
|
|
|
+ if reranking_model:
|
|
|
|
+ app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
|
|
|
+ get_model_path(reranking_model, update_model),
|
|
|
|
+ device=DEVICE_TYPE,
|
|
|
|
+ trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ app.state.sentence_transformer_rf = None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+update_embedding_model(
|
|
|
|
+ app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+update_reranking_model(
|
|
|
|
+ app.state.RAG_RERANKING_MODEL,
|
|
|
|
+ RAG_RERANKING_MODEL_AUTO_UPDATE,
|
|
|
|
+)
|
|
|
|
|
|
origins = ["*"]
|
|
origins = ["*"]
|
|
|
|
|
|
|
|
+
|
|
app.add_middleware(
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
CORSMiddleware,
|
|
allow_origins=origins,
|
|
allow_origins=origins,
|
|
@@ -200,15 +221,7 @@ async def update_embedding_config(
|
|
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
|
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
|
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
|
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
|
|
|
|
|
- app.state.sentence_transformer_ef = None
|
|
|
|
- else:
|
|
|
|
- app.state.sentence_transformer_ef = (
|
|
|
|
- sentence_transformers.SentenceTransformer(
|
|
|
|
- app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- device=DEVICE_TYPE,
|
|
|
|
- trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
|
|
+ update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
|
|
|
|
|
|
return {
|
|
return {
|
|
"status": True,
|
|
"status": True,
|
|
@@ -219,7 +232,6 @@ async def update_embedding_config(
|
|
"key": app.state.OPENAI_API_KEY,
|
|
"key": app.state.OPENAI_API_KEY,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
-
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
log.exception(f"Problem updating embedding model: {e}")
|
|
log.exception(f"Problem updating embedding model: {e}")
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
@@ -242,13 +254,7 @@ async def update_reranking_config(
|
|
try:
|
|
try:
|
|
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
|
|
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
|
|
|
|
|
|
- if app.state.RAG_RERANKING_MODEL == "":
|
|
|
|
- app.state.sentence_transformer_rf = None
|
|
|
|
- else:
|
|
|
|
- app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
|
|
|
- app.state.RAG_RERANKING_MODEL,
|
|
|
|
- device=DEVICE_TYPE,
|
|
|
|
- )
|
|
|
|
|
|
+ update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
|
|
|
|
|
|
return {
|
|
return {
|
|
"status": True,
|
|
"status": True,
|