
docker improvements & a unified device type env var for the different models used

Jannik Streidl · 1 year ago
commit 1f6739337b
4 changed files with 35 additions and 18 deletions
  1. Dockerfile (+24 −11)
  2. backend/apps/audio/main.py (+6 −2)
  3. backend/apps/rag/main.py (+3 −3)
  4. backend/config.py (+2 −2)

Dockerfile (+24 −11)

@@ -1,4 +1,7 @@
 # syntax=docker/dockerfile:1
+# Initialize device type args
+ARG USE_CUDA=false
+ARG USE_MPS=false
 
 ######## WebUI frontend ########
 FROM node:21-alpine3.19 as build
@@ -23,6 +26,10 @@ RUN npm run build
 ######## WebUI backend ########
 FROM python:3.11-slim-bookworm as base
 
+# Use args
+ARG USE_CUDA
+ARG USE_MPS
+
 ## Basis ##
 ENV ENV=prod \
     PORT=8080
@@ -54,7 +61,8 @@ ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" \
     # Important:
     #  If you want to use CUDA you need to install the nvidia-container-toolkit (https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) 
    #  you can set this to "cuda", but it's recommended to use the --build-arg USE_CUDA=true flag when building the image
-    RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
+    RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu" \
+    DEVICE_COMPUTE_TYPE="int8"
 # device type for the Whisper (speech-to-text) and embedding models - "cpu" (default), "cuda" (NVIDIA GPU and CUDA required) or "mps" (Apple silicon) - choosing the right one can lead to better performance
 #### Preloaded models ##########################################################
 
@@ -62,19 +70,24 @@ WORKDIR /app/backend
 # install python dependencies
 COPY ./backend/requirements.txt ./requirements.txt
 
-RUN pip3 install -r requirements.txt --no-cache-dir
-
-RUN if [ "$RAG_EMBEDDING_MODEL_DEVICE_TYPE" = "cuda" ]; then \
-        echo "CUDA enabled" && \
-        pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 --no-cache-dir; \
+RUN if [ "$USE_CUDA" = "true" ]; then \
+        export DEVICE_TYPE="cuda" && \
+        pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 --no-cache-dir && \
+        pip3 install -r requirements.txt --no-cache-dir; \
+    elif [ "$USE_MPS" = "true" ]; then \
+        export DEVICE_TYPE="mps" && \
+        pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
+        pip3 install -r requirements.txt --no-cache-dir && \
+        python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
+        python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['DEVICE_TYPE'])"; \
     else \
+        export DEVICE_TYPE="cpu" && \
         pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
-        python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"; \
+        pip3 install -r requirements.txt --no-cache-dir && \
+        python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
+        python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['DEVICE_TYPE'])"; \
     fi
 
-# preload tts model
-RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
-
 #  install required packages
 RUN apt-get update \
     # Install pandoc and netcat
@@ -100,4 +113,4 @@ COPY ./backend .
 
 EXPOSE 8080
 
-CMD [ "bash", "start.sh"]
+CMD [ "bash", "start.sh"]

backend/apps/audio/main.py (+6 −2)

@@ -21,7 +21,11 @@ from utils.utils import (
 )
 from utils.misc import calculate_sha256
 
-from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
+from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR, DEVICE_TYPE
+
+# faster-whisper runs on CTranslate2, which supports "cpu" and "cuda" but
+# not "mps", so anything other than "cuda" falls back to CPU
+whisper_device_type = DEVICE_TYPE if DEVICE_TYPE == "cuda" else "cpu"
+
 
 app = FastAPI()
 app.add_middleware(
@@ -56,7 +60,7 @@ def transcribe(
 
         model = WhisperModel(
             WHISPER_MODEL,
-            device="auto",
+            device=whisper_device_type,
             compute_type="int8",
             download_root=WHISPER_MODEL_DIR,
         )
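
The same fallback can also be exercised as a preload one-liner in the style of the Dockerfile RUN steps above. A sketch, assuming WHISPER_MODEL, WHISPER_MODEL_DIR, and DEVICE_TYPE are already exported:

    # preload Whisper with the same device fallback the API applies per request;
    # faster-whisper has no MPS backend, hence CPU for anything but "cuda"
    python -c "import os; from faster_whisper import WhisperModel; device = 'cuda' if os.environ.get('DEVICE_TYPE') == 'cuda' else 'cpu'; WhisperModel(os.environ['WHISPER_MODEL'], device=device, compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"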

backend/apps/rag/main.py (+3 −3)

@@ -57,7 +57,7 @@ from config import (
     UPLOAD_DIR,
     DOCS_DIR,
     RAG_EMBEDDING_MODEL,
-    RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+    DEVICE_TYPE,
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
@@ -87,7 +87,7 @@ app.state.TOP_K = 4
 app.state.sentence_transformer_ef = (
     embedding_functions.SentenceTransformerEmbeddingFunction(
         model_name=app.state.RAG_EMBEDDING_MODEL,
-        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+        device=DEVICE_TYPE,
     )
 )
 
@@ -175,7 +175,7 @@ async def update_embedding_model(
     app.state.sentence_transformer_ef = (
         embedding_functions.SentenceTransformerEmbeddingFunction(
             model_name=app.state.RAG_EMBEDDING_MODEL,
-            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+            device=DEVICE_TYPE,
         )
     )
 

backend/config.py (+2 −2)

@@ -330,8 +330,8 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 # this uses the model defined in the Dockerfile ENV variable. If you don't use Docker or Docker-based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
 RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
 # device type for embedding models - "cpu" (default), "cuda" (NVIDIA GPU required) or "mps" (Apple silicon) - choosing the right one can lead to better performance
-RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
-    "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
+DEVICE_TYPE = os.environ.get(
+    "DEVICE_TYPE", "cpu"
 )
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
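
Because the exported DEVICE_TYPE is build-time only (see the note after the Dockerfile diff), the runtime value comes from the container environment and defaults to "cpu". A sketch; the image tag and port mapping are illustrative assumptions:

    # run the CUDA image with a matching runtime device type
    docker run -d --gpus all -p 3000:8080 -e DEVICE_TYPE=cuda open-webui:cuda

On Apple silicon, MPS is generally not reachable from inside Docker's Linux VM, so DEVICE_TYPE="mps" is mainly useful for bare-metal (non-Docker) deployments.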