Sfoglia il codice sorgente

storing vectordb in project cache folder + device types

Jannik Streidl 1 anno fa
parent
commit
acf999013b
4 ha cambiato i file con 24 aggiunte e 5 eliminazioni
  1. 9 3
      Dockerfile
  2. 1 1
      backend/apps/audio/main.py
  3. 11 1
      backend/apps/rag/main.py
  4. 3 0
      backend/config.py

+ 9 - 3
Dockerfile

@@ -30,15 +30,21 @@ ENV WEBUI_SECRET_KEY ""
 ENV SCARF_NO_ANALYTICS true
 ENV DO_NOT_TRACK true
 
+######## Preloaded models ########
 # whisper TTS Settings
 ENV WHISPER_MODEL="base"
 ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
 
+# RAG Embedding Model Settings
 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
 # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard 
-# for better persormance and multilangauge support use "intfloat/multilingual-e5-large"
+# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
 # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
 ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
+ENV SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
+# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
+ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
+######## Preloaded models ########
 
 WORKDIR /app/backend
 
@@ -55,9 +61,9 @@ RUN apt-get update \
     && rm -rf /var/lib/apt/lists/*
 
 # preload embedding model
-RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'])"
+RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"
 # preload tts model
-RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
+RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
 
 
 # copy embedding weight from build

+ 1 - 1
backend/apps/audio/main.py

@@ -56,7 +56,7 @@ def transcribe(
 
         model = WhisperModel(
             WHISPER_MODEL,
-            device="cpu",
+            device="auto",
             compute_type="int8",
             download_root=WHISPER_MODEL_DIR,
         )

+ 11 - 1
backend/apps/rag/main.py

@@ -13,6 +13,7 @@ import os, shutil
 from pathlib import Path
 from typing import List
 
+from sentence_transformers import SentenceTransformer
 from chromadb.utils import embedding_functions
 
 from langchain_community.document_loaders import (
@@ -52,6 +53,7 @@ from config import (
     UPLOAD_DIR,
     DOCS_DIR,
     RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_DEVICE_TYPE,
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
@@ -60,10 +62,18 @@ from config import (
 
 from constants import ERROR_MESSAGES
 
+#
+#if RAG_EMBEDDING_MODEL:
+#    sentence_transformer_ef = SentenceTransformer(
+#        model_name_or_path=RAG_EMBEDDING_MODEL,
+#        cache_folder=RAG_EMBEDDING_MODEL_DIR,
+#        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+#    )
 
 if RAG_EMBEDDING_MODEL:
     sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name=RAG_EMBEDDING_MODEL
+        model_name=RAG_EMBEDDING_MODEL,
+        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
     )
 
 app = FastAPI()

+ 3 - 0
backend/config.py

@@ -138,6 +138,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
 RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "")
+
+# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
+RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get("RAG_EMBEDDING_MODEL_DEVICE_TYPE", "")
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
     settings=Settings(allow_reset=True, anonymized_telemetry=False),