Преглед изворни кода

choose embedding model when using docker

Jannik Streidl пре 1 година
родитељ
комит
1846c1e80d
3 измењених фајлова са 46 додато и 20 уклоњено
  1. 10 2
      Dockerfile
  2. 34 17
      backend/apps/rag/main.py
  3. 2 1
      backend/config.py

+ 10 - 2
Dockerfile

@@ -30,10 +30,16 @@ ENV WEBUI_SECRET_KEY ""
 ENV SCARF_NO_ANALYTICS true
 ENV DO_NOT_TRACK true
 
-#Whisper TTS Settings
+# whisper TTS Settings
 ENV WHISPER_MODEL="base"
 ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
 
+# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
+# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard 
+# for better persormance and multilangauge support use "intfloat/multilingual-e5-large"
+# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
+ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2"
+
 WORKDIR /app/backend
 
 # install python dependencies
@@ -48,7 +54,9 @@ RUN apt-get update \
     && apt-get install -y pandoc netcat-openbsd \
     && rm -rf /var/lib/apt/lists/*
 
-# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')"
+# preload embedding model
+RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])"
+# preload tts model
 RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
 
 

+ 34 - 17
backend/apps/rag/main.py

@@ -1,6 +1,5 @@
 from fastapi import (
     FastAPI,
-    Request,
     Depends,
     HTTPException,
     status,
@@ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import os, shutil
 from typing import List
 
-# from chromadb.utils import embedding_functions
+from chromadb.utils import embedding_functions
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
@@ -28,24 +27,19 @@ from langchain_community.document_loaders import (
     UnstructuredExcelLoader,
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
-from langchain.chains import RetrievalQA
 
 
 from pydantic import BaseModel
 from typing import Optional
 
 import uuid
-import time
 
 from utils.misc import calculate_sha256, calculate_sha256_string
 from utils.utils import get_current_user, get_admin_user
-from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
+from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
 from constants import ERROR_MESSAGES
 
-# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
-#     model_name=EMBED_MODEL
-# )
+sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
 
 app = FastAPI()
 
@@ -78,11 +72,17 @@ def store_data_in_vector_db(data, collection_name) -> bool:
     metadatas = [doc.metadata for doc in docs]
 
     try:
-        collection = CHROMA_CLIENT.create_collection(name=collection_name)
+        if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
+    # if you use docker use the model from the environment variable
+            collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef)
+        
+        else:
+    # for local development use the default model
+            collection = CHROMA_CLIENT.create_collection(name=collection_name)
 
         collection.add(
-            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
-        )
+        documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
+         )
         return True
     except Exception as e:
         print(e)
@@ -109,9 +109,17 @@ def query_doc(
     user=Depends(get_current_user),
 ):
     try:
-        collection = CHROMA_CLIENT.get_collection(
-            name=form_data.collection_name,
-        )
+        if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
+        # if you use docker use the model from the environment variable
+            collection = CHROMA_CLIENT.get_collection(
+                name=form_data.collection_name,
+                embedding_function=sentence_transformer_ef
+            )
+        else:
+        # for local development use the default model
+                collection = CHROMA_CLIENT.get_collection(
+                name=form_data.collection_name,
+            )
         result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
         return result
     except Exception as e:
@@ -182,9 +190,18 @@ def query_collection(
 
     for collection_name in form_data.collection_names:
         try:
-            collection = CHROMA_CLIENT.get_collection(
-                name=collection_name,
+            if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
+             # if you use docker use the model from the environment variable
+                collection = CHROMA_CLIENT.get_collection(
+                    name=form_data.collection_name,
+                    embedding_function=sentence_transformer_ef
+                )
+            else:
+            # for local development use the default model
+                collection = CHROMA_CLIENT.get_collection(
+                name=form_data.collection_name,
             )
+                
             result = collection.query(
                 query_texts=[form_data.query], n_results=form_data.k
             )

+ 2 - 1
backend/config.py

@@ -128,7 +128,8 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 ####################################
 
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
-EMBED_MODEL = "all-MiniLM-L6-v2"
+# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
+SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL")
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
     settings=Settings(allow_reset=True, anonymized_telemetry=False),