Browse Source

Merge pull request #772 from jannikstdl/choose-embedding-model

feat: choose embedding model when using docker
Timothy Jaeryang Baek 1 year ago
parent
commit
c3916927bb
4 changed files with 87 additions and 17 deletions
  1. 19 4
      Dockerfile
  2. 1 1
      backend/apps/audio/main.py
  3. 61 11
      backend/apps/rag/main.py
  4. 6 1
      backend/config.py

+ 19 - 4
Dockerfile

@@ -30,10 +30,24 @@ ENV WEBUI_SECRET_KEY ""
 ENV SCARF_NO_ANALYTICS true
 ENV DO_NOT_TRACK true
 
-#Whisper TTS Settings
+######## Preloaded models ########
+# whisper TTS Settings
 ENV WHISPER_MODEL="base"
 ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
 
+# RAG Embedding Model Settings
+# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
+# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard 
+# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
+# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
+ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
+# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
+ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
+ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
+ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
+
+######## Preloaded models ########
+
 WORKDIR /app/backend
 
 # install python dependencies
@@ -48,9 +62,10 @@ RUN apt-get update \
     && apt-get install -y pandoc netcat-openbsd \
     && rm -rf /var/lib/apt/lists/*
 
-# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')"
-RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
-
+# preload embedding model
+RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"
+# preload tts model
+RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
 
 # copy embedding weight from build
 RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2

+ 1 - 1
backend/apps/audio/main.py

@@ -56,7 +56,7 @@ def transcribe(
 
         model = WhisperModel(
             WHISPER_MODEL,
-            device="cpu",
+            device="auto",
             compute_type="int8",
             download_root=WHISPER_MODEL_DIR,
         )

+ 61 - 11
backend/apps/rag/main.py

@@ -1,6 +1,5 @@
 from fastapi import (
     FastAPI,
-    Request,
     Depends,
     HTTPException,
     status,
@@ -14,7 +13,8 @@ import os, shutil
 from pathlib import Path
 from typing import List
 
-# from chromadb.utils import embedding_functions
+from sentence_transformers import SentenceTransformer
+from chromadb.utils import embedding_functions
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
@@ -30,16 +30,12 @@ from langchain_community.document_loaders import (
     UnstructuredExcelLoader,
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.chains import RetrievalQA
-from langchain_community.vectorstores import Chroma
-
 
 from pydantic import BaseModel
 from typing import Optional
 import mimetypes
 import uuid
 import json
-import time
 
 
 from apps.web.models.documents import (
@@ -58,23 +54,37 @@ from utils.utils import get_current_user, get_admin_user
 from config import (
     UPLOAD_DIR,
     DOCS_DIR,
-    EMBED_MODEL,
+    RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_DEVICE_TYPE,
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
     RAG_TEMPLATE,
 )
+
 from constants import ERROR_MESSAGES
 
-# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
-#     model_name=EMBED_MODEL
-# )
+#
+# if RAG_EMBEDDING_MODEL:
+#    sentence_transformer_ef = SentenceTransformer(
+#        model_name_or_path=RAG_EMBEDDING_MODEL,
+#        cache_folder=RAG_EMBEDDING_MODEL_DIR,
+#        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+#    )
+
 
 app = FastAPI()
 
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
+app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.sentence_transformer_ef = (
+    embedding_functions.SentenceTransformerEmbeddingFunction(
+        model_name=app.state.RAG_EMBEDDING_MODEL,
+        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+    )
+)
 
 
 origins = ["*"]
@@ -106,7 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
     metadatas = [doc.metadata for doc in docs]
 
     try:
-        collection = CHROMA_CLIENT.create_collection(name=collection_name)
+        collection = CHROMA_CLIENT.create_collection(
+            name=collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )
 
         collection.add(
             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
@@ -126,6 +139,38 @@ async def get_status():
         "status": True,
         "chunk_size": app.state.CHUNK_SIZE,
         "chunk_overlap": app.state.CHUNK_OVERLAP,
+        "template": app.state.RAG_TEMPLATE,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+    }
+
+
+@app.get("/embedding/model")
+async def get_embedding_model(user=Depends(get_admin_user)):
+    return {
+        "status": True,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+    }
+
+
+class EmbeddingModelUpdateForm(BaseModel):
+    embedding_model: str
+
+
+@app.post("/embedding/model/update")
+async def update_embedding_model(
+    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
+):
+    app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+    app.state.sentence_transformer_ef = (
+        embedding_functions.SentenceTransformerEmbeddingFunction(
+            model_name=app.state.RAG_EMBEDDING_MODEL,
+            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+        )
+    )
+
+    return {
+        "status": True,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
     }
 
 
@@ -190,8 +235,10 @@ def query_doc(
     user=Depends(get_current_user),
 ):
     try:
+        # if you use docker use the model from the environment variable
         collection = CHROMA_CLIENT.get_collection(
             name=form_data.collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
         )
         result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
         return result
@@ -263,9 +310,12 @@ def query_collection(
 
     for collection_name in form_data.collection_names:
         try:
+            # if you use docker use the model from the environment variable
             collection = CHROMA_CLIENT.get_collection(
                 name=collection_name,
+                embedding_function=app.state.sentence_transformer_ef,
             )
+
             result = collection.query(
                 query_texts=[form_data.query], n_results=form_data.k
             )

+ 6 - 1
backend/config.py

@@ -136,7 +136,12 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 ####################################
 
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
-EMBED_MODEL = "all-MiniLM-L6-v2"
+# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
+RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
+# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
+RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
+    "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
+)
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
     settings=Settings(allow_reset=True, anonymized_telemetry=False),