Timothy J. Baek 1 年間 前
コミット
ab104d5905
2 ファイル変更7 行追加5 行削除
  1. 3 2
      Dockerfile
  2. 4 3
      backend/config.py

+ 3 - 2
Dockerfile

@@ -41,9 +41,11 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
 # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
 # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
 ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
-ENV SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
 # device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
 ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
+ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
+ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
+
 ######## Preloaded models ########
 
 WORKDIR /app/backend
@@ -65,7 +67,6 @@ RUN python -c "import os; from chromadb.utils import embedding_functions; senten
 # preload tts model
 RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
 
-
 # copy embedding weight from build
 RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
 COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx

+ 4 - 3
backend/config.py

@@ -137,10 +137,11 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
-RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "")
-
+RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
 # device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
-RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get("RAG_EMBEDDING_MODEL_DEVICE_TYPE", "")
+RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
+    "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
+)
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
     settings=Settings(allow_reset=True, anonymized_telemetry=False),