|
@@ -97,62 +97,58 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
##########################################
|
|
|
|
|
|
|
|
|
-def update_embedding_model(
|
|
|
- request: Request,
|
|
|
+def get_ef(
|
|
|
+ engine: str,
|
|
|
embedding_model: str,
|
|
|
auto_update: bool = False,
|
|
|
):
|
|
|
- if embedding_model and request.app.state.config.RAG_EMBEDDING_ENGINE == "":
|
|
|
+ ef = None
|
|
|
+ if embedding_model and engine == "":
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
try:
|
|
|
- request.app.state.sentence_transformer_ef = SentenceTransformer(
|
|
|
+ ef = SentenceTransformer(
|
|
|
get_model_path(embedding_model, auto_update),
|
|
|
device=DEVICE_TYPE,
|
|
|
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
|
)
|
|
|
except Exception as e:
|
|
|
log.debug(f"Error loading SentenceTransformer: {e}")
|
|
|
- request.app.state.sentence_transformer_ef = None
|
|
|
- else:
|
|
|
- request.app.state.sentence_transformer_ef = None
|
|
|
|
|
|
+ return ef
|
|
|
|
|
|
-def update_reranking_model(
|
|
|
- request: Request,
|
|
|
+
|
|
|
+def get_rf(
|
|
|
reranking_model: str,
|
|
|
auto_update: bool = False,
|
|
|
):
|
|
|
+ rf = None
|
|
|
if reranking_model:
|
|
|
if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
|
|
|
try:
|
|
|
from open_webui.retrieval.models.colbert import ColBERT
|
|
|
|
|
|
- request.app.state.sentence_transformer_rf = ColBERT(
|
|
|
+ rf = ColBERT(
|
|
|
get_model_path(reranking_model, auto_update),
|
|
|
env="docker" if DOCKER else None,
|
|
|
)
|
|
|
+
|
|
|
except Exception as e:
|
|
|
log.error(f"ColBERT: {e}")
|
|
|
- request.app.state.sentence_transformer_rf = None
|
|
|
- request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
|
|
+ raise Exception(ERROR_MESSAGES.DEFAULT(e))
|
|
|
else:
|
|
|
import sentence_transformers
|
|
|
|
|
|
try:
|
|
|
- request.app.state.sentence_transformer_rf = (
|
|
|
- sentence_transformers.CrossEncoder(
|
|
|
- get_model_path(reranking_model, auto_update),
|
|
|
- device=DEVICE_TYPE,
|
|
|
- trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
|
- )
|
|
|
+ rf = sentence_transformers.CrossEncoder(
|
|
|
+ get_model_path(reranking_model, auto_update),
|
|
|
+ device=DEVICE_TYPE,
|
|
|
+ trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
|
)
|
|
|
except:
|
|
|
log.error("CrossEncoder error")
|
|
|
- request.app.state.sentence_transformer_rf = None
|
|
|
- request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
|
|
- else:
|
|
|
- request.app.state.sentence_transformer_rf = None
|
|
|
+ raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
|
|
|
+ return rf
|
|
|
|
|
|
|
|
|
##########################################
|
|
@@ -261,12 +257,15 @@ async def update_embedding_config(
|
|
|
form_data.embedding_batch_size
|
|
|
)
|
|
|
|
|
|
- update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL)
|
|
|
+ request.app.state.ef = get_ef(
|
|
|
+ request.app.state.config.RAG_EMBEDDING_ENGINE,
|
|
|
+ request.app.state.config.RAG_EMBEDDING_MODEL,
|
|
|
+ )
|
|
|
|
|
|
request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
|
|
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
|
|
request.app.state.config.RAG_EMBEDDING_MODEL,
|
|
|
- request.app.state.sentence_transformer_ef,
|
|
|
+ request.app.state.ef,
|
|
|
(
|
|
|
request.app.state.config.OPENAI_API_BASE_URL
|
|
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
|
@@ -316,7 +315,14 @@ async def update_reranking_config(
|
|
|
try:
|
|
|
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
|
|
|
|
|
|
- update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True)
|
|
|
+ try:
|
|
|
+ request.app.state.rf = get_rf(
|
|
|
+ request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
+ True,
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error loading reranking model: {e}")
|
|
|
+ request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
|
|
|
|
|
return {
|
|
|
"status": True,
|
|
@@ -739,7 +745,7 @@ def save_docs_to_vector_db(
|
|
|
embedding_function = get_embedding_function(
|
|
|
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
|
|
request.app.state.config.RAG_EMBEDDING_MODEL,
|
|
|
- request.app.state.sentence_transformer_ef,
|
|
|
+ request.app.state.ef,
|
|
|
(
|
|
|
request.app.state.config.OPENAI_API_BASE_URL
|
|
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
|
@@ -1286,7 +1292,7 @@ def query_doc_handler(
|
|
|
query=form_data.query,
|
|
|
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
|
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
|
|
- reranking_function=request.app.state.sentence_transformer_rf,
|
|
|
+ reranking_function=request.app.state.rf,
|
|
|
r=(
|
|
|
form_data.r
|
|
|
if form_data.r
|
|
@@ -1328,7 +1334,7 @@ def query_collection_handler(
|
|
|
queries=[form_data.query],
|
|
|
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
|
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
|
|
- reranking_function=request.app.state.sentence_transformer_rf,
|
|
|
+ reranking_function=request.app.state.rf,
|
|
|
r=(
|
|
|
form_data.r
|
|
|
if form_data.r
|