Timothy J. Baek 7 months ago
parent
commit
67f95ddfdc
1 changed files with 10 additions and 9 deletions
  1. 10 9
      backend/open_webui/apps/rag/main.py

+ 10 - 9
backend/open_webui/apps/rag/main.py

@@ -180,13 +180,13 @@ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_
 
 def update_embedding_model(
     embedding_model: str,
-    update_model: bool = False,
+    auto_update: bool = False,
 ):
     if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
         import sentence_transformers
 
         app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-            get_model_path(embedding_model, update_model),
+            get_model_path(embedding_model, auto_update),
             device=DEVICE_TYPE,
             trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
         )
@@ -196,17 +196,18 @@ def update_embedding_model(
 
 def update_reranking_model(
     reranking_model: str,
-    update_model: bool = False,
+    auto_update: bool = False,
 ):
     if reranking_model:
         if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
 
-            class Colbert:
+            class ColBERT:
                 def __init__(self, name) -> None:
                     self.device = "cuda" if torch.cuda.is_available() else "cpu"
-                    self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to(
-                        self.device
-                    )
+                    self.ckpt = Checkpoint(
+                        get_model_path(name, auto_update),
+                        colbert_config=ColBERTConfig(),
+                    ).to(self.device)
                     pass
 
                 def calculate_similarity_scores(
@@ -264,13 +265,13 @@ def update_reranking_model(
 
                     return scores
 
-            app.state.sentence_transformer_rf = Colbert(reranking_model)
+            app.state.sentence_transformer_rf = ColBERT(reranking_model)
         else:
             import sentence_transformers
 
             try:
                 app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-                    get_model_path(reranking_model, update_model),
+                    get_model_path(reranking_model, auto_update),
                     device=DEVICE_TYPE,
                     trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
                 )