浏览代码

Merge pull request #1693 from buroa/buroa/hybrid-search

feat: hybrid search with reranking
Timothy Jaeryang Baek 1 年之前
父节点
当前提交
5ee2f1729a
共有 8 个文件被更改,包括 650 次插入171 次删除
  1. 4 0
      CHANGELOG.md
  2. 9 3
      Dockerfile
  3. 123 91
      backend/apps/rag/main.py
  4. 295 72
      backend/apps/rag/utils.py
  5. 24 4
      backend/config.py
  6. 2 0
      backend/main.py
  7. 62 0
      src/lib/apis/rag/index.ts
  8. 131 1
      src/lib/components/documents/Settings/General.svelte

+ 4 - 0
CHANGELOG.md

@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.1.122] - 2024-04-24
+
+- **🌟 Enhanced RAG Pipeline**: Added hybrid searching with `BM25`, reranking using `CrossEncoder`, and relevance score thresholds.
+
 ## [0.1.121] - 2024-04-24
 
 ### Fixed

+ 9 - 3
Dockerfile

@@ -8,8 +8,9 @@ ARG USE_CUDA_VER=cu121
 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
 # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard 
 # for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
-# IMPORTANT: If you change the default model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
+# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
 ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
+ARG USE_RERANKING_MODEL=""
 
 ######## WebUI frontend ########
 FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
@@ -30,6 +31,7 @@ ARG USE_CUDA
 ARG USE_OLLAMA
 ARG USE_CUDA_VER
 ARG USE_EMBEDDING_MODEL
+ARG USE_RERANKING_MODEL
 
 ## Basis ##
 ENV ENV=prod \
@@ -38,7 +40,8 @@ ENV ENV=prod \
     USE_OLLAMA_DOCKER=${USE_OLLAMA} \
     USE_CUDA_DOCKER=${USE_CUDA} \
     USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
-    USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL}
+    USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
+    USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
 
 ## Basis URL Config ##
 ENV OLLAMA_BASE_URL="/ollama" \
@@ -62,8 +65,11 @@ ENV WHISPER_MODEL="base" \
 
 ## RAG Embedding model settings ##
 ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
-    RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models" \
+    RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
     SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
+
+## Hugging Face download cache ##
+ENV HF_HOME="/app/backend/data/cache/embedding/models"
 #### Other models ##########################################################
 
 WORKDIR /app/backend

+ 123 - 91
backend/apps/rag/main.py

@@ -39,8 +39,6 @@ import json
 
 import sentence_transformers
 
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
-
 from apps.web.models.documents import (
     Documents,
     DocumentForm,
@@ -48,9 +46,10 @@ from apps.web.models.documents import (
 )
 
 from apps.rag.utils import (
+    get_model_path,
     query_embeddings_doc,
+    query_embeddings_function,
     query_embeddings_collection,
-    generate_openai_embeddings,
 )
 
 from utils.misc import (
@@ -60,13 +59,20 @@ from utils.misc import (
     extract_folders_after_data_docs,
 )
 from utils.utils import get_current_user, get_admin_user
+
 from config import (
     SRC_LOG_LEVELS,
     UPLOAD_DIR,
     DOCS_DIR,
+    RAG_TOP_K,
+    RAG_RELEVANCE_THRESHOLD,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+    RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
+    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_OPENAI_API_BASE_URL,
     RAG_OPENAI_API_KEY,
     DEVICE_TYPE,
@@ -83,14 +89,14 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 app = FastAPI()
 
-
-app.state.TOP_K = 4
+app.state.TOP_K = RAG_TOP_K
+app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 
-
 app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 
 app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
@@ -98,16 +104,48 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
 app.state.PDF_EXTRACT_IMAGES = False
 
-if app.state.RAG_EMBEDDING_ENGINE == "":
-    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-        app.state.RAG_EMBEDDING_MODEL,
-        device=DEVICE_TYPE,
-        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    )
 
+def update_embedding_model(
+    embedding_model: str,
+    update_model: bool = False,
+):
+    if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
+        app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+            get_model_path(embedding_model, update_model),
+            device=DEVICE_TYPE,
+            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+        )
+    else:
+        app.state.sentence_transformer_ef = None
+
+
+def update_reranking_model(
+    reranking_model: str,
+    update_model: bool = False,
+):
+    if reranking_model:
+        app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+            get_model_path(reranking_model, update_model),
+            device=DEVICE_TYPE,
+            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+        )
+    else:
+        app.state.sentence_transformer_rf = None
+
+
+update_embedding_model(
+    app.state.RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+)
+
+update_reranking_model(
+    app.state.RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
+)
 
 origins = ["*"]
 
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=origins,
@@ -134,6 +172,7 @@ async def get_status():
         "template": app.state.RAG_TEMPLATE,
         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "reranking_model": app.state.RAG_RERANKING_MODEL,
     }
 
 
@@ -150,6 +189,11 @@ async def get_embedding_config(user=Depends(get_admin_user)):
     }
 
 
+@app.get("/reranking")
+async def get_reraanking_config(user=Depends(get_admin_user)):
+    return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
+
+
 class OpenAIConfigForm(BaseModel):
     url: str
     key: str
@@ -170,22 +214,14 @@ async def update_embedding_config(
     )
     try:
         app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
+        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
 
         if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
-            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-            app.state.sentence_transformer_ef = None
-
             if form_data.openai_config != None:
                 app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                 app.state.OPENAI_API_KEY = form_data.openai_config.key
-        else:
-            sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-                app.state.RAG_EMBEDDING_MODEL,
-                device=DEVICE_TYPE,
-                trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-            )
-            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-            app.state.sentence_transformer_ef = sentence_transformer_ef
+
+        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
 
         return {
             "status": True,
@@ -196,7 +232,6 @@ async def update_embedding_config(
                 "key": app.state.OPENAI_API_KEY,
             },
         }
-
     except Exception as e:
         log.exception(f"Problem updating embedding model: {e}")
         raise HTTPException(
@@ -205,6 +240,34 @@ async def update_embedding_config(
         )
 
 
+class RerankingModelUpdateForm(BaseModel):
+    reranking_model: str
+
+
+@app.post("/reranking/update")
+async def update_reranking_config(
+    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
+):
+    log.info(
+        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
+    )
+    try:
+        app.state.RAG_RERANKING_MODEL = form_data.reranking_model
+
+        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
+
+        return {
+            "status": True,
+            "reranking_model": app.state.RAG_RERANKING_MODEL,
+        }
+    except Exception as e:
+        log.exception(f"Problem updating reranking model: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
 @app.get("/config")
 async def get_rag_config(user=Depends(get_admin_user)):
     return {
@@ -257,11 +320,13 @@ async def get_query_settings(user=Depends(get_admin_user)):
         "status": True,
         "template": app.state.RAG_TEMPLATE,
         "k": app.state.TOP_K,
+        "r": app.state.RELEVANCE_THRESHOLD,
     }
 
 
 class QuerySettingsForm(BaseModel):
     k: Optional[int] = None
+    r: Optional[float] = None
     template: Optional[str] = None
 
 
@@ -271,6 +336,7 @@ async def update_query_settings(
 ):
     app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
     app.state.TOP_K = form_data.k if form_data.k else 4
+    app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
     return {"status": True, "template": app.state.RAG_TEMPLATE}
 
 
@@ -278,6 +344,7 @@ class QueryDocForm(BaseModel):
     collection_name: str
     query: str
     k: Optional[int] = None
+    r: Optional[float] = None
 
 
 @app.post("/query/doc")
@@ -286,34 +353,22 @@ def query_doc_handler(
     user=Depends(get_current_user),
 ):
     try:
-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            query_embeddings = app.state.sentence_transformer_ef.encode(
-                form_data.query
-            ).tolist()
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-            query_embeddings = generate_openai_embeddings(
-                model=app.state.RAG_EMBEDDING_MODEL,
-                text=form_data.query,
-                key=app.state.OPENAI_API_KEY,
-                url=app.state.OPENAI_API_BASE_URL,
-            )
+        embeddings_function = query_embeddings_function(
+            app.state.RAG_EMBEDDING_ENGINE,
+            app.state.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.OPENAI_API_KEY,
+            app.state.OPENAI_API_BASE_URL,
+        )
 
         return query_embeddings_doc(
             collection_name=form_data.collection_name,
             query=form_data.query,
-            query_embeddings=query_embeddings,
             k=form_data.k if form_data.k else app.state.TOP_K,
+            r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
+            embeddings_function=embeddings_function,
+            reranking_function=app.state.sentence_transformer_rf,
         )
-
     except Exception as e:
         log.exception(e)
         raise HTTPException(
@@ -326,6 +381,7 @@ class QueryCollectionsForm(BaseModel):
     collection_names: List[str]
     query: str
     k: Optional[int] = None
+    r: Optional[float] = None
 
 
 @app.post("/query/collection")
@@ -334,33 +390,22 @@ def query_collection_handler(
     user=Depends(get_current_user),
 ):
     try:
-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            query_embeddings = app.state.sentence_transformer_ef.encode(
-                form_data.query
-            ).tolist()
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-            query_embeddings = generate_openai_embeddings(
-                model=app.state.RAG_EMBEDDING_MODEL,
-                text=form_data.query,
-                key=app.state.OPENAI_API_KEY,
-                url=app.state.OPENAI_API_BASE_URL,
-            )
+        embeddings_function = query_embeddings_function(
+            app.state.RAG_EMBEDDING_ENGINE,
+            app.state.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.OPENAI_API_KEY,
+            app.state.OPENAI_API_BASE_URL,
+        )
 
         return query_embeddings_collection(
             collection_names=form_data.collection_names,
-            query_embeddings=query_embeddings,
+            query=form_data.query,
             k=form_data.k if form_data.k else app.state.TOP_K,
+            r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
+            embeddings_function=embeddings_function,
+            reranking_function=app.state.sentence_transformer_rf,
         )
-
     except Exception as e:
         log.exception(e)
         raise HTTPException(
@@ -427,8 +472,6 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
     log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
-    texts = list(map(lambda x: x.replace("\n", " "), texts))
-
     metadatas = [doc.metadata for doc in docs]
 
     try:
@@ -440,27 +483,16 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
 
         collection = CHROMA_CLIENT.create_collection(name=collection_name)
 
-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            embeddings = [
-                generate_ollama_embeddings(
-                    GenerateEmbeddingsForm(
-                        **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
-                    )
-                )
-                for text in texts
-            ]
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-            embeddings = [
-                generate_openai_embeddings(
-                    model=app.state.RAG_EMBEDDING_MODEL,
-                    text=text,
-                    key=app.state.OPENAI_API_KEY,
-                    url=app.state.OPENAI_API_BASE_URL,
-                )
-                for text in texts
-            ]
+        embedding_func = query_embeddings_function(
+            app.state.RAG_EMBEDDING_ENGINE,
+            app.state.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.OPENAI_API_KEY,
+            app.state.OPENAI_API_BASE_URL,
+        )
+
+        embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
+        embeddings = embedding_func(embedding_texts)
 
         for batch in create_batches(
             api=CHROMA_CLIENT,

+ 295 - 72
backend/apps/rag/utils.py

@@ -1,3 +1,4 @@
+import os
 import logging
 import requests
 
@@ -8,6 +9,15 @@ from apps.ollama.main import (
     GenerateEmbeddingsForm,
 )
 
+from huggingface_hub import snapshot_download
+
+from langchain_core.documents import Document
+from langchain_community.retrievers import BM25Retriever
+from langchain.retrievers import (
+    ContextualCompressionRetriever,
+    EnsembleRetriever,
+)
+
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
 
@@ -15,18 +25,53 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
+def query_embeddings_doc(
+    collection_name: str,
+    query: str,
+    k: int,
+    r: float,
+    embeddings_function,
+    reranking_function,
+):
     try:
         # if you use docker use the model from the environment variable
-        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(name=collection_name)
 
-        result = collection.query(
-            query_embeddings=[query_embeddings],
-            n_results=k,
+        documents = collection.get()  # get all documents
+        bm25_retriever = BM25Retriever.from_texts(
+            texts=documents.get("documents"),
+            metadatas=documents.get("metadatas"),
+        )
+        bm25_retriever.k = k
+
+        chroma_retriever = ChromaRetriever(
+            collection=collection,
+            embeddings_function=embeddings_function,
+            top_n=k,
+        )
+
+        ensemble_retriever = EnsembleRetriever(
+            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
+        )
+
+        compressor = RerankCompressor(
+            embeddings_function=embeddings_function,
+            reranking_function=reranking_function,
+            r_score=r,
+            top_n=k,
         )
 
-        log.info(f"query_embeddings_doc:result {result}")
+        compression_retriever = ContextualCompressionRetriever(
+            base_compressor=compressor, base_retriever=ensemble_retriever
+        )
+
+        result = compression_retriever.invoke(query)
+        result = {
+            "distances": [[d.metadata.get("score") for d in result]],
+            "documents": [[d.page_content for d in result]],
+            "metadatas": [[d.metadata for d in result]],
+        }
+
         return result
     except Exception as e:
         raise e
@@ -34,63 +79,65 @@ def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k:
 
 def merge_and_sort_query_results(query_results, k):
     # Initialize lists to store combined data
-    combined_ids = []
     combined_distances = []
-    combined_metadatas = []
     combined_documents = []
+    combined_metadatas = []
 
-    # Combine data from each dictionary
     for data in query_results:
-        combined_ids.extend(data["ids"][0])
         combined_distances.extend(data["distances"][0])
-        combined_metadatas.extend(data["metadatas"][0])
         combined_documents.extend(data["documents"][0])
+        combined_metadatas.extend(data["metadatas"][0])
 
-    # Create a list of tuples (distance, id, metadata, document)
-    combined = list(
-        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
-    )
+    # Create a list of tuples (distance, document, metadata)
+    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
 
     # Sort the list based on distances
     combined.sort(key=lambda x: x[0])
 
-    # Unzip the sorted list
-    sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
+    # We don't have anything :-(
+    if not combined:
+        sorted_distances = []
+        sorted_documents = []
+        sorted_metadatas = []
+    else:
+        # Unzip the sorted list
+        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
 
-    # Slicing the lists to include only k elements
-    sorted_distances = list(sorted_distances)[:k]
-    sorted_ids = list(sorted_ids)[:k]
-    sorted_metadatas = list(sorted_metadatas)[:k]
-    sorted_documents = list(sorted_documents)[:k]
+        # Slicing the lists to include only k elements
+        sorted_distances = list(sorted_distances)[:k]
+        sorted_documents = list(sorted_documents)[:k]
+        sorted_metadatas = list(sorted_metadatas)[:k]
 
     # Create the output dictionary
-    merged_query_results = {
-        "ids": [sorted_ids],
+    result = {
         "distances": [sorted_distances],
-        "metadatas": [sorted_metadatas],
         "documents": [sorted_documents],
-        "embeddings": None,
-        "uris": None,
-        "data": None,
+        "metadatas": [sorted_metadatas],
     }
 
-    return merged_query_results
+    return result
 
 
 def query_embeddings_collection(
-    collection_names: List[str], query: str, query_embeddings, k: int
+    collection_names: List[str],
+    query: str,
+    k: int,
+    r: float,
+    embeddings_function,
+    reranking_function,
 ):
 
     results = []
-    log.info(f"query_embeddings_collection {query_embeddings}")
 
     for collection_name in collection_names:
         try:
             result = query_embeddings_doc(
                 collection_name=collection_name,
                 query=query,
-                query_embeddings=query_embeddings,
                 k=k,
+                r=r,
+                embeddings_function=embeddings_function,
+                reranking_function=reranking_function,
             )
             results.append(result)
         except:
@@ -105,19 +152,57 @@ def rag_template(template: str, context: str, query: str):
     return template
 
 
+def query_embeddings_function(
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    openai_key,
+    openai_url,
+):
+    if embedding_engine == "":
+        return lambda query: embedding_function.encode(query).tolist()
+    elif embedding_engine in ["ollama", "openai"]:
+        if embedding_engine == "ollama":
+            func = lambda query: generate_ollama_embeddings(
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": embedding_model,
+                        "prompt": query,
+                    }
+                )
+            )
+        elif embedding_engine == "openai":
+            func = lambda query: generate_openai_embeddings(
+                model=embedding_model,
+                text=query,
+                key=openai_key,
+                url=openai_url,
+            )
+
+        def generate_multiple(query, f):
+            if isinstance(query, list):
+                return [f(q) for q in query]
+            else:
+                return f(query)
+
+        return lambda query: generate_multiple(query, func)
+
+
 def rag_messages(
     docs,
     messages,
     template,
     k,
+    r,
     embedding_engine,
     embedding_model,
     embedding_function,
+    reranking_function,
     openai_key,
     openai_url,
 ):
     log.debug(
-        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
+        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
     )
 
     last_user_message_idx = None
@@ -145,62 +230,66 @@ def rag_messages(
         content_type = None
         query = ""
 
+    embeddings_function = query_embeddings_function(
+        embedding_engine,
+        embedding_model,
+        embedding_function,
+        openai_key,
+        openai_url,
+    )
+
+    extracted_collections = []
     relevant_contexts = []
 
     for doc in docs:
         context = None
 
-        try:
+        collection = doc.get("collection_name")
+        if collection:
+            collection = [collection]
+        else:
+            collection = doc.get("collection_names", [])
+
+        collection = set(collection).difference(extracted_collections)
+        if not collection:
+            log.debug(f"skipping {doc} as it has already been extracted")
+            continue
 
+        try:
             if doc["type"] == "text":
                 context = doc["content"]
+            elif doc["type"] == "collection":
+                context = query_embeddings_collection(
+                    collection_names=doc["collection_names"],
+                    query=query,
+                    k=k,
+                    r=r,
+                    embeddings_function=embeddings_function,
+                    reranking_function=reranking_function,
+                )
             else:
-                if embedding_engine == "":
-                    query_embeddings = embedding_function.encode(query).tolist()
-                elif embedding_engine == "ollama":
-                    query_embeddings = generate_ollama_embeddings(
-                        GenerateEmbeddingsForm(
-                            **{
-                                "model": embedding_model,
-                                "prompt": query,
-                            }
-                        )
-                    )
-                elif embedding_engine == "openai":
-                    query_embeddings = generate_openai_embeddings(
-                        model=embedding_model,
-                        text=query,
-                        key=openai_key,
-                        url=openai_url,
-                    )
-
-                if doc["type"] == "collection":
-                    context = query_embeddings_collection(
-                        collection_names=doc["collection_names"],
-                        query=query,
-                        query_embeddings=query_embeddings,
-                        k=k,
-                    )
-                else:
-                    context = query_embeddings_doc(
-                        collection_name=doc["collection_name"],
-                        query=query,
-                        query_embeddings=query_embeddings,
-                        k=k,
-                    )
-
+                context = query_embeddings_doc(
+                    collection_name=doc["collection_name"],
+                    query=query,
+                    k=k,
+                    r=r,
+                    embeddings_function=embeddings_function,
+                    reranking_function=reranking_function,
+                )
         except Exception as e:
             log.exception(e)
             context = None
 
-        relevant_contexts.append(context)
+        if context:
+            relevant_contexts.append(context)
 
-    log.debug(f"relevant_contexts: {relevant_contexts}")
+        extracted_collections.extend(collection)
 
     context_string = ""
     for context in relevant_contexts:
-        if context:
-            context_string += " ".join(context["documents"][0]) + "\n"
+        items = context["documents"][0]
+        context_string += "\n\n".join(items)
+    context_string = context_string.strip()
 
     ra_content = rag_template(
         template=template,
@@ -208,6 +297,8 @@ def rag_messages(
         query=query,
     )
 
+    log.debug(f"ra_content: {ra_content}")
+
     if content_type == "list":
         new_content = []
         for content_item in user_message["content"]:
@@ -229,6 +320,44 @@ def rag_messages(
     return messages
 
 
+def get_model_path(model: str, update_model: bool = False):
+    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
+    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
+
+    local_files_only = not update_model
+
+    snapshot_kwargs = {
+        "cache_dir": cache_dir,
+        "local_files_only": local_files_only,
+    }
+
+    log.debug(f"model: {model}")
+    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
+
+    # Inspiration from upstream sentence_transformers
+    if (
+        os.path.exists(model)
+        or ("\\" in model or model.count("/") > 1)
+        and local_files_only
+    ):
+        # If fully qualified path exists, return input, else set repo_id
+        return model
+    elif "/" not in model:
+        # Set valid repo_id for model short-name
+        model = "sentence-transformers" + "/" + model
+
+    snapshot_kwargs["repo_id"] = model
+
+    # Attempt to query the huggingface_hub library to determine the local path and/or to update
+    try:
+        model_repo_path = snapshot_download(**snapshot_kwargs)
+        log.debug(f"model_repo_path: {model_repo_path}")
+        return model_repo_path
+    except Exception as e:
+        log.exception(f"Cannot determine model snapshot path: {e}")
+        return model
+
+
 def generate_openai_embeddings(
     model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
 ):
@@ -250,3 +379,97 @@ def generate_openai_embeddings(
     except Exception as e:
         print(e)
         return None
+
+
+from typing import Any
+
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+
+
+class ChromaRetriever(BaseRetriever):
+    collection: Any
+    embeddings_function: Any
+    top_n: int
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+    ) -> List[Document]:
+        query_embeddings = self.embeddings_function(query)
+
+        results = self.collection.query(
+            query_embeddings=[query_embeddings],
+            n_results=self.top_n,
+        )
+
+        ids = results["ids"][0]
+        metadatas = results["metadatas"][0]
+        documents = results["documents"][0]
+
+        return [
+            Document(
+                metadata=metadatas[idx],
+                page_content=documents[idx],
+            )
+            for idx in range(len(ids))
+        ]
+
+
+import operator
+
+from typing import Optional, Sequence
+
+from langchain_core.documents import BaseDocumentCompressor, Document
+from langchain_core.callbacks import Callbacks
+from langchain_core.pydantic_v1 import Extra
+
+from sentence_transformers import util
+
+
+class RerankCompressor(BaseDocumentCompressor):
+    embeddings_function: Any
+    reranking_function: Any
+    r_score: float
+    top_n: int
+
+    class Config:
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        if self.reranking_function:
+            scores = self.reranking_function.predict(
+                [(query, doc.page_content) for doc in documents]
+            )
+        else:
+            query_embedding = self.embeddings_function(query)
+            document_embedding = self.embeddings_function(
+                [doc.page_content for doc in documents]
+            )
+            scores = util.cos_sim(query_embedding, document_embedding)[0]
+
+        docs_with_scores = list(zip(documents, scores.tolist()))
+        if self.r_score:
+            docs_with_scores = [
+                (d, s) for d, s in docs_with_scores if s >= self.r_score
+            ]
+
+        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
+        final_results = []
+        for doc, doc_score in result[: self.top_n]:
+            metadata = doc.metadata
+            metadata["score"] = doc_score
+            doc = Document(
+                page_content=doc.page_content,
+                metadata=metadata,
+            )
+            final_results.append(doc)
+        return final_results

+ 24 - 4
backend/config.py

@@ -420,6 +420,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
 
+RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
+RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
+
 RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
 
 RAG_EMBEDDING_MODEL = os.environ.get(
@@ -427,10 +430,26 @@ RAG_EMBEDDING_MODEL = os.environ.get(
 )
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
 
+RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
+    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
 RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
 
+RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
+if not RAG_RERANKING_MODEL == "":
+    log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
+
+RAG_RERANKING_MODEL_AUTO_UPDATE = (
+    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
+RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
+    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
+)
+
 # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
 USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
 
@@ -439,16 +458,15 @@ if USE_CUDA.lower() == "true":
 else:
     DEVICE_TYPE = "cpu"
 
-
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
     settings=Settings(allow_reset=True, anonymized_telemetry=False),
 )
-CHUNK_SIZE = 1500
-CHUNK_OVERLAP = 100
 
+CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
+CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
 
-RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
+DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
 <context>
     [context]
 </context>
@@ -462,6 +480,8 @@ And answer according to the language of the user's question.
 Given the context information, answer the query.
 Query: [query]"""
 
+RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
+
 RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
 RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
 

+ 2 - 0
backend/main.py

@@ -120,9 +120,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     data["messages"],
                     rag_app.state.RAG_TEMPLATE,
                     rag_app.state.TOP_K,
+                    rag_app.state.RELEVANCE_THRESHOLD,
                     rag_app.state.RAG_EMBEDDING_ENGINE,
                     rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.sentence_transformer_ef,
+                    rag_app.state.sentence_transformer_rf,
                     rag_app.state.OPENAI_API_KEY,
                     rag_app.state.OPENAI_API_BASE_URL,
                 )

+ 62 - 0
src/lib/apis/rag/index.ts

@@ -123,6 +123,7 @@ export const getQuerySettings = async (token: string) => {
 
 type QuerySettings = {
 	k: number | null;
+	r: number | null;
 	template: string | null;
 };
 
@@ -413,3 +414,64 @@ export const updateEmbeddingConfig = async (token: string, payload: EmbeddingMod
 
 	return res;
 };
+
+export const getRerankingConfig = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/reranking`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+type RerankingModelUpdateForm = {
+	reranking_model: string;
+};
+
+export const updateRerankingConfig = async (token: string, payload: RerankingModelUpdateForm) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/reranking/update`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...payload
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 131 - 1
src/lib/components/documents/Settings/General.svelte

@@ -8,7 +8,9 @@
 		updateQuerySettings,
 		resetVectorDB,
 		getEmbeddingConfig,
-		updateEmbeddingConfig
+		updateEmbeddingConfig,
+		getRerankingConfig,
+		updateRerankingConfig
 	} from '$lib/apis/rag';
 
 	import { documents, models } from '$lib/stores';
@@ -23,11 +25,13 @@
 
 	let scanDirLoading = false;
 	let updateEmbeddingModelLoading = false;
+	let updateRerankingModelLoading = false;
 
 	let showResetConfirm = false;
 
 	let embeddingEngine = '';
 	let embeddingModel = '';
+	let rerankingModel = '';
 
 	let OpenAIKey = '';
 	let OpenAIUrl = '';
@@ -38,6 +42,7 @@
 
 	let querySettings = {
 		template: '',
+		r: 0.0,
 		k: 4
 	};
 
@@ -115,6 +120,29 @@
 		}
 	};
 
+	const rerankingModelUpdateHandler = async () => {
+		console.log('Update reranking model attempt:', rerankingModel);
+
+		updateRerankingModelLoading = true;
+		const res = await updateRerankingConfig(localStorage.token, {
+			reranking_model: rerankingModel
+		}).catch(async (error) => {
+			toast.error(error);
+			await setRerankingConfig();
+			return null;
+		});
+		updateRerankingModelLoading = false;
+
+		if (res) {
+			console.log('rerankingModelUpdateHandler:', res);
+			if (res.status === true) {
+				toast.success($i18n.t('Reranking model set to "{{reranking_model}}"', res), {
+					duration: 1000 * 10
+				});
+			}
+		}
+	};
+
 	const submitHandler = async () => {
 		const res = await updateRAGConfig(localStorage.token, {
 			pdf_extract_images: pdfExtractImages,
@@ -138,6 +166,14 @@
 		}
 	};
 
+	const setRerankingConfig = async () => {
+		const rerankingConfig = await getRerankingConfig(localStorage.token);
+
+		if (rerankingConfig) {
+			rerankingModel = rerankingConfig.reranking_model;
+		}
+	};
+
 	onMount(async () => {
 		const res = await getRAGConfig(localStorage.token);
 
@@ -149,6 +185,7 @@
 		}
 
 		await setEmbeddingConfig();
+		await setRerankingConfig();
 
 		querySettings = await getQuerySettings(localStorage.token);
 	});
@@ -349,6 +386,79 @@
 
 				<hr class=" dark:border-gray-700 my-3" />
 
+				<div class=" ">
+					<div class=" mb-2 text-sm font-medium">{$i18n.t('Update Reranking Model')}</div>
+
+					<div class="flex w-full">
+						<div class="flex-1 mr-2">
+							<input
+								class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+								placeholder={$i18n.t('Update reranking model (e.g. {{model}})', {
+									model: rerankingModel.slice(-40)
+								})}
+								bind:value={rerankingModel}
+							/>
+						</div>
+						<button
+							class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+							on:click={() => {
+								rerankingModelUpdateHandler();
+							}}
+							disabled={updateRerankingModelLoading}
+						>
+							{#if updateRerankingModelLoading}
+								<div class="self-center">
+									<svg
+										class=" w-4 h-4"
+										viewBox="0 0 24 24"
+										fill="currentColor"
+										xmlns="http://www.w3.org/2000/svg"
+										><style>
+											.spinner_ajPY {
+												transform-origin: center;
+												animation: spinner_AtaB 0.75s infinite linear;
+											}
+											@keyframes spinner_AtaB {
+												100% {
+													transform: rotate(360deg);
+												}
+											}
+										</style><path
+											d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
+											opacity=".25"
+										/><path
+											d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
+											class="spinner_ajPY"
+										/></svg
+									>
+								</div>
+							{:else}
+								<svg
+									xmlns="http://www.w3.org/2000/svg"
+									viewBox="0 0 16 16"
+									fill="currentColor"
+									class="w-4 h-4"
+								>
+									<path
+										d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z"
+									/>
+									<path
+										d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
+									/>
+								</svg>
+							{/if}
+						</button>
+					</div>
+				</div>
+
+				<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
+					{$i18n.t(
+						'Note: If you choose a reranking model, it will use that to score and rerank instead of the embedding model.'
+					)}
+				</div>
+
+				<hr class=" dark:border-gray-700 my-3" />
+
 				<div class="  flex w-full justify-between">
 					<div class=" self-center text-xs font-medium">
 						{$i18n.t('Scan for documents from {{path}}', { path: '/data/docs' })}
@@ -473,6 +583,26 @@
 						</div>
 					</div>
 
+					<div class=" flex">
+						<div class="  flex w-full justify-between">
+							<div class="self-center text-xs font-medium flex-1">
+								{$i18n.t('Relevance Threshold')}
+							</div>
+
+							<div class="self-center p-3">
+								<input
+									class=" w-full rounded-lg py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+									type="number"
+									step="0.01"
+									placeholder={$i18n.t('Enter Relevance Threshold')}
+									bind:value={querySettings.r}
+									autocomplete="off"
+									min="0.0"
+								/>
+							</div>
+						</div>
+					</div>
+
 					<div>
 						<div class=" mb-2.5 text-sm font-medium">{$i18n.t('RAG Template')}</div>
 						<textarea