Browse Source

refac: more descriptive var names

Timothy J. Baek 1 year ago
parent
commit
0cb0358485
3 changed files with 26 additions and 21 deletions
  1. 2 2
      Dockerfile
  2. 23 18
      backend/apps/rag/main.py
  3. 1 1
      backend/config.py

+ 2 - 2
Dockerfile

@@ -38,7 +38,7 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
 # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard 
 # for better persormance and multilangauge support use "intfloat/multilingual-e5-large"
 # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
-ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2"
+ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
 
 WORKDIR /app/backend
 
@@ -55,7 +55,7 @@ RUN apt-get update \
     && rm -rf /var/lib/apt/lists/*
 
 # preload embedding model
-RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])"
+RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'])"
 # preload tts model
 RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
 

+ 23 - 18
backend/apps/rag/main.py

@@ -51,7 +51,7 @@ from utils.utils import get_current_user, get_admin_user
 from config import (
     UPLOAD_DIR,
     DOCS_DIR,
-    SENTENCE_TRANSFORMER_EMBED_MODEL,
+    RAG_EMBEDDING_MODEL,
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
@@ -60,7 +60,11 @@ from config import (
 
 from constants import ERROR_MESSAGES
 
-sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
+
+if RAG_EMBEDDING_MODEL:
+    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+        model_name=RAG_EMBEDDING_MODEL
+    )
 
 app = FastAPI()
 
@@ -98,17 +102,18 @@ def store_data_in_vector_db(data, collection_name) -> bool:
     metadatas = [doc.metadata for doc in docs]
 
     try:
-        if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
-    # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef)
-        
+        if RAG_EMBEDDING_MODEL:
+            # if you use docker use the model from the environment variable
+            collection = CHROMA_CLIENT.create_collection(
+                name=collection_name, embedding_function=sentence_transformer_ef
+            )
         else:
-    # for local development use the default model
+            # for local development use the default model
             collection = CHROMA_CLIENT.create_collection(name=collection_name)
 
         collection.add(
-        documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
-         )
+            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
+        )
         return True
     except Exception as e:
         print(e)
@@ -188,16 +193,16 @@ def query_doc(
     user=Depends(get_current_user),
 ):
     try:
-        if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
-        # if you use docker use the model from the environment variable
+        if RAG_EMBEDDING_MODEL:
+            # if you use docker use the model from the environment variable
             collection = CHROMA_CLIENT.get_collection(
                 name=form_data.collection_name,
                 embedding_function=sentence_transformer_ef,
             )
         else:
-        # for local development use the default model
+            # for local development use the default model
             collection = CHROMA_CLIENT.get_collection(
-            name=form_data.collection_name,
+                name=form_data.collection_name,
             )
         result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
         return result
@@ -269,18 +274,18 @@ def query_collection(
 
     for collection_name in form_data.collection_names:
         try:
-            if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
-            # if you use docker use the model from the environment variable
+            if RAG_EMBEDDING_MODEL:
+                # if you use docker use the model from the environment variable
                 collection = CHROMA_CLIENT.get_collection(
                     name=collection_name,
                     embedding_function=sentence_transformer_ef,
                 )
             else:
-            # for local development use the default model
+                # for local development use the default model
                 collection = CHROMA_CLIENT.get_collection(
-                name=collection_name,
+                    name=collection_name,
                 )
-                
+
             result = collection.query(
                 query_texts=[form_data.query], n_results=form_data.k
             )

+ 1 - 1
backend/config.py

@@ -137,7 +137,7 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
-SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL")
+RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "")
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
     settings=Settings(allow_reset=True, anonymized_telemetry=False),