Timothy Jaeryang Baek il y a 4 mois
Parent
commit
ccdf51588e
2 fichiers modifiés avec 70 ajouts et 68 suppressions
  1. backend/open_webui/main.py (+36 −40)
  2. backend/open_webui/routers/retrieval.py (+34 −28)

+ 36 - 40
backend/open_webui/main.py

@@ -39,6 +39,13 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import Response, StreamingResponse
 
 
+from open_webui.socket.main import (
+    app as socket_app,
+    periodic_usage_pool_cleanup,
+    get_event_call,
+    get_event_emitter,
+)
+
 from open_webui.routers import (
     audio,
     images,
@@ -63,35 +70,19 @@ from open_webui.routers import (
     users,
     utils,
 )
-from open_webui.retrieval.utils import get_sources_from_files
 from open_webui.routers.retrieval import (
     get_embedding_function,
-    update_embedding_model,
-    update_reranking_model,
+    get_ef,
+    get_rf,
 )
+from open_webui.retrieval.utils import get_sources_from_files
 
 
-from open_webui.socket.main import (
-    app as socket_app,
-    periodic_usage_pool_cleanup,
-    get_event_call,
-    get_event_emitter,
-)
-
 from open_webui.internal.db import Session
 
-
-from open_webui.routers.webui import (
-    app as webui_app,
-    generate_function_chat_completion,
-    get_all_models as get_open_webui_models,
-)
-
-
 from open_webui.models.functions import Functions
 from open_webui.models.models import Models
 from open_webui.models.users import UserModel, Users
-from open_webui.utils.plugin import load_function_module_by_id
 
 
 from open_webui.constants import TASKS
@@ -279,7 +270,7 @@ from open_webui.env import (
     OFFLINE_MODE,
 )
 
-
+from open_webui.utils.plugin import load_function_module_by_id
 from open_webui.utils.misc import (
     add_or_update_system_message,
     get_last_user_message,
@@ -528,8 +519,8 @@ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
 app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
 
 app.state.EMBEDDING_FUNCTION = None
-app.state.sentence_transformer_ef = None
-app.state.sentence_transformer_rf = None
+app.state.ef = None
+app.state.rf = None
 
 app.state.YOUTUBE_LOADER_TRANSLATION = None
 
@@ -537,29 +528,34 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
 app.state.EMBEDDING_FUNCTION = get_embedding_function(
     app.state.config.RAG_EMBEDDING_ENGINE,
     app.state.config.RAG_EMBEDDING_MODEL,
-    app.state.sentence_transformer_ef,
+    app.state.ef,
     (
-        app.state.config.OPENAI_API_BASE_URL
+        app.state.config.RAG_OPENAI_API_BASE_URL
         if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-        else app.state.config.OLLAMA_BASE_URL
+        else app.state.config.RAG_OLLAMA_BASE_URL
     ),
     (
-        app.state.config.OPENAI_API_KEY
+        app.state.config.RAG_OPENAI_API_KEY
         if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-        else app.state.config.OLLAMA_API_KEY
+        else app.state.config.RAG_OLLAMA_API_KEY
     ),
     app.state.config.RAG_EMBEDDING_BATCH_SIZE,
 )
 
-update_embedding_model(
-    app.state.config.RAG_EMBEDDING_MODEL,
-    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
-)
+try:
+    app.state.ef = get_ef(
+        app.state.config.RAG_EMBEDDING_ENGINE,
+        app.state.config.RAG_EMBEDDING_MODEL,
+        RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+    )
 
-update_reranking_model(
-    app.state.config.RAG_RERANKING_MODEL,
-    RAG_RERANKING_MODEL_AUTO_UPDATE,
-)
+    app.state.rf = get_rf(
+        app.state.config.RAG_RERANKING_MODEL,
+        RAG_RERANKING_MODEL_AUTO_UPDATE,
+    )
+except Exception as e:
+    log.error(f"Error updating models: {e}")
+    pass
 
 
 ########################################
@@ -990,11 +986,11 @@ async def chat_completion_files_handler(
         sources = get_sources_from_files(
             files=files,
             queries=queries,
-            embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
-            k=retrieval_app.state.config.TOP_K,
-            reranking_function=retrieval_app.state.sentence_transformer_rf,
-            r=retrieval_app.state.config.RELEVANCE_THRESHOLD,
-            hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+            embedding_function=app.state.EMBEDDING_FUNCTION,
+            k=app.state.config.TOP_K,
+            reranking_function=app.state.rf,
+            r=app.state.config.RELEVANCE_THRESHOLD,
+            hybrid_search=app.state.config.ENABLE_RAG_HYBRID_SEARCH,
         )
 
         log.debug(f"rag_contexts:sources: {sources}")

+ 34 - 28
backend/open_webui/routers/retrieval.py

@@ -97,62 +97,58 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 ##########################################
 
 
-def update_embedding_model(
-    request: Request,
+def get_ef(
+    engine: str,
     embedding_model: str,
     auto_update: bool = False,
 ):
-    if embedding_model and request.app.state.config.RAG_EMBEDDING_ENGINE == "":
+    ef = None
+    if embedding_model and engine == "":
         from sentence_transformers import SentenceTransformer
 
         try:
-            request.app.state.sentence_transformer_ef = SentenceTransformer(
+            ef = SentenceTransformer(
                 get_model_path(embedding_model, auto_update),
                 device=DEVICE_TYPE,
                 trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
             )
         except Exception as e:
             log.debug(f"Error loading SentenceTransformer: {e}")
-            request.app.state.sentence_transformer_ef = None
-    else:
-        request.app.state.sentence_transformer_ef = None
 
+    return ef
 
-def update_reranking_model(
-    request: Request,
+
+def get_rf(
     reranking_model: str,
     auto_update: bool = False,
 ):
+    rf = None
     if reranking_model:
         if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
             try:
                 from open_webui.retrieval.models.colbert import ColBERT
 
-                request.app.state.sentence_transformer_rf = ColBERT(
+                rf = ColBERT(
                     get_model_path(reranking_model, auto_update),
                     env="docker" if DOCKER else None,
                 )
+
             except Exception as e:
                 log.error(f"ColBERT: {e}")
-                request.app.state.sentence_transformer_rf = None
-                request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
+                raise Exception(ERROR_MESSAGES.DEFAULT(e))
         else:
             import sentence_transformers
 
             try:
-                request.app.state.sentence_transformer_rf = (
-                    sentence_transformers.CrossEncoder(
-                        get_model_path(reranking_model, auto_update),
-                        device=DEVICE_TYPE,
-                        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-                    )
+                rf = sentence_transformers.CrossEncoder(
+                    get_model_path(reranking_model, auto_update),
+                    device=DEVICE_TYPE,
+                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
                 )
             except:
                 log.error("CrossEncoder error")
-                request.app.state.sentence_transformer_rf = None
-                request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
-    else:
-        request.app.state.sentence_transformer_rf = None
+                raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
+    return rf
 
 
 ##########################################
@@ -261,12 +257,15 @@ async def update_embedding_config(
                 form_data.embedding_batch_size
             )
 
-        update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL)
+        request.app.state.ef = get_ef(
+            request.app.state.config.RAG_EMBEDDING_ENGINE,
+            request.app.state.config.RAG_EMBEDDING_MODEL,
+        )
 
         request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
             request.app.state.config.RAG_EMBEDDING_ENGINE,
             request.app.state.config.RAG_EMBEDDING_MODEL,
-            request.app.state.sentence_transformer_ef,
+            request.app.state.ef,
             (
                 request.app.state.config.OPENAI_API_BASE_URL
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
@@ -316,7 +315,14 @@ async def update_reranking_config(
     try:
         request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
 
-        update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True)
+        try:
+            request.app.state.rf = get_rf(
+                request.app.state.config.RAG_RERANKING_MODEL,
+                True,
+            )
+        except Exception as e:
+            log.error(f"Error loading reranking model: {e}")
+            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
 
         return {
             "status": True,
@@ -739,7 +745,7 @@ def save_docs_to_vector_db(
         embedding_function = get_embedding_function(
             request.app.state.config.RAG_EMBEDDING_ENGINE,
             request.app.state.config.RAG_EMBEDDING_MODEL,
-            request.app.state.sentence_transformer_ef,
+            request.app.state.ef,
             (
                 request.app.state.config.OPENAI_API_BASE_URL
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
@@ -1286,7 +1292,7 @@ def query_doc_handler(
                 query=form_data.query,
                 embedding_function=request.app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
-                reranking_function=request.app.state.sentence_transformer_rf,
+                reranking_function=request.app.state.rf,
                 r=(
                     form_data.r
                     if form_data.r
@@ -1328,7 +1334,7 @@ def query_collection_handler(
                 queries=[form_data.query],
                 embedding_function=request.app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
-                reranking_function=request.app.state.sentence_transformer_rf,
+                reranking_function=request.app.state.rf,
                 r=(
                     form_data.r
                     if form_data.r