Pārlūkot izejas kodu

fix: address comment in pr #1687

Steven Kreitzer 1 gadu atpakaļ
vecāks
revīzija
c9c9660459
4 mainītis faili ar 93 papildinājumiem un 44 dzēšanām
  1. 0 4
      backend/apps/ollama/main.py
  2. 44 38
      backend/apps/rag/main.py
  3. 41 2
      backend/apps/rag/utils.py
  4. 8 0
      backend/config.py

+ 0 - 4
backend/apps/ollama/main.py

@@ -92,10 +92,6 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
     return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
 
 
-def get_ollama_endpoint(url_idx: int = 0):
-    return app.state.OLLAMA_BASE_URLS[url_idx]
-
-
 class UrlUpdateForm(BaseModel):
     urls: List[str]
 

+ 44 - 38
backend/apps/rag/main.py

@@ -39,8 +39,6 @@ import json
 
 import sentence_transformers
 
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
-
 from apps.web.models.documents import (
     Documents,
     DocumentForm,
@@ -48,6 +46,7 @@ from apps.web.models.documents import (
 )
 
 from apps.rag.utils import (
+    get_model_path,
     query_embeddings_doc,
     query_embeddings_function,
     query_embeddings_collection,
@@ -60,6 +59,7 @@ from utils.misc import (
     extract_folders_after_data_docs,
 )
 from utils.utils import get_current_user, get_admin_user
+
 from config import (
     SRC_LOG_LEVELS,
     UPLOAD_DIR,
@@ -68,8 +68,10 @@ from config import (
     RAG_RELEVANCE_THRESHOLD,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_OPENAI_API_BASE_URL,
     RAG_OPENAI_API_KEY,
@@ -87,13 +89,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 app = FastAPI()
 
-
 app.state.TOP_K = RAG_TOP_K
 app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 
-
 app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
@@ -104,27 +104,48 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
 app.state.PDF_EXTRACT_IMAGES = False
 
-if app.state.RAG_EMBEDDING_ENGINE == "":
-    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-        app.state.RAG_EMBEDDING_MODEL,
-        device=DEVICE_TYPE,
-        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    )
-else:
-    app.state.sentence_transformer_ef = None
-
-if not app.state.RAG_RERANKING_MODEL == "":
-    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-        app.state.RAG_RERANKING_MODEL,
-        device=DEVICE_TYPE,
-        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-    )
-else:
-    app.state.sentence_transformer_rf = None
 
+def update_embedding_model(
+    embedding_model: str,
+    update_model: bool = False,
+):
+    if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
+        app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+            get_model_path(embedding_model, update_model),
+            device=DEVICE_TYPE,
+            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+        )
+    else:
+        app.state.sentence_transformer_ef = None
+
+
+def update_reranking_model(
+    reranking_model: str,
+    update_model: bool = False,
+):
+    if reranking_model:
+        app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+            get_model_path(reranking_model, update_model),
+            device=DEVICE_TYPE,
+            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+        )
+    else:
+        app.state.sentence_transformer_rf = None
+
+
+update_embedding_model(
+    app.state.RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+)
+
+update_reranking_model(
+    app.state.RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
+)
 
 origins = ["*"]
 
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=origins,
@@ -200,15 +221,7 @@ async def update_embedding_config(
                 app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                 app.state.OPENAI_API_KEY = form_data.openai_config.key
 
-            app.state.sentence_transformer_ef = None
-        else:
-            app.state.sentence_transformer_ef = (
-                sentence_transformers.SentenceTransformer(
-                    app.state.RAG_EMBEDDING_MODEL,
-                    device=DEVICE_TYPE,
-                    trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-                )
-            )
+        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
 
         return {
             "status": True,
@@ -219,7 +232,6 @@ async def update_embedding_config(
                 "key": app.state.OPENAI_API_KEY,
             },
         }
-
     except Exception as e:
         log.exception(f"Problem updating embedding model: {e}")
         raise HTTPException(
@@ -242,13 +254,7 @@ async def update_reranking_config(
     try:
         app.state.RAG_RERANKING_MODEL = form_data.reranking_model
 
-        if app.state.RAG_RERANKING_MODEL == "":
-            app.state.sentence_transformer_rf = None
-        else:
-            app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-                app.state.RAG_RERANKING_MODEL,
-                device=DEVICE_TYPE,
-            )
+        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
 
         return {
             "status": True,

+ 41 - 2
backend/apps/rag/utils.py

@@ -1,3 +1,4 @@
+import os
 import logging
 import requests
 
@@ -8,6 +9,8 @@ from apps.ollama.main import (
     GenerateEmbeddingsForm,
 )
 
+from huggingface_hub import snapshot_download
+
 from langchain_core.documents import Document
 from langchain_community.retrievers import BM25Retriever
 from langchain.retrievers import (
@@ -282,8 +285,6 @@ def rag_messages(
 
         extracted_collections.extend(collection)
 
-    log.debug(f"relevant_contexts: {relevant_contexts}")
-
     context_string = ""
     for context in relevant_contexts:
         items = context["documents"][0]
@@ -319,6 +320,44 @@ def rag_messages(
     return messages
 
 
+def get_model_path(model: str, update_model: bool = False):
+    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
+    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
+
+    local_files_only = not update_model
+
+    snapshot_kwargs = {
+        "cache_dir": cache_dir,
+        "local_files_only": local_files_only,
+    }
+
+    log.debug(f"embedding_model: {model}")
+    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
+
+    # Inspiration from upstream sentence_transformers
+    if (
+        os.path.exists(model)
+        or ("\\" in model or model.count("/") > 1)
+        and local_files_only
+    ):
+        # If fully qualified path exists, return input, else set repo_id
+        return model
+    elif "/" not in model:
+        # Set valid repo_id for model short-name
+        model = "sentence-transformers" + "/" + model
+
+    snapshot_kwargs["repo_id"] = model
+
+    # Attempt to query the huggingface_hub library to determine the local path and/or to update
+    try:
+        model_repo_path = snapshot_download(**snapshot_kwargs)
+        log.debug(f"model_repo_path: {model_repo_path}")
+        return model_repo_path
+    except Exception as e:
+        log.exception(f"Cannot determine model snapshot path: {e}")
+        return model
+
+
 def generate_openai_embeddings(
     model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
 ):

+ 8 - 0
backend/config.py

@@ -430,6 +430,10 @@ RAG_EMBEDDING_MODEL = os.environ.get(
 )
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
 
+RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
+    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
 RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
@@ -438,6 +442,10 @@ RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
 if not RAG_RERANKING_MODEL == "":
     log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
 
+RAG_RERANKING_MODEL_AUTO_UPDATE = (
+    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
 RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )