Prechádzať zdrojové kódy

enh: colbert rerank support

Timothy J. Baek 7 mesiacov pred
rodič
commit
b38986a0aa

+ 74 - 11
backend/open_webui/apps/rag/main.py

@@ -10,6 +10,9 @@ from datetime import datetime
 from pathlib import Path
 from typing import Iterator, Optional, Sequence, Union
 
+
+import numpy as np
+import torch
 import requests
 import validators
 
@@ -114,6 +117,8 @@ from langchain_community.document_loaders import (
     YoutubeLoader,
 )
 from langchain_core.documents import Document
+from colbert.infra import ColBERTConfig
+from colbert.modeling.checkpoint import Checkpoint
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -193,18 +198,76 @@ def update_reranking_model(
     update_model: bool = False,
 ):
     if reranking_model:
-        import sentence_transformers
+        if reranking_model in ["jinaai/jina-colbert-v2"]:
 
-        try:
-            app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-                get_model_path(reranking_model, update_model),
-                device=DEVICE_TYPE,
-                trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-            )
-        except:
-            log.error("CrossEncoder error")
-            app.state.sentence_transformer_rf = None
-            app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
+            class Colbert:
+                def __init__(self, name) -> None:
+                    self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
+                    pass
+
+                def calculate_similarity_scores(query_embeddings, document_embeddings):
+                    # Validate dimensions to ensure compatibility
+                    if query_embeddings.dim() != 3:
+                        raise ValueError(
+                            f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
+                        )
+                    if document_embeddings.dim() != 3:
+                        raise ValueError(
+                            f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
+                        )
+                    if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
+                        raise ValueError(
+                            "There should be either one query or queries equal to the number of documents."
+                        )
+
+                    # Transpose the query embeddings to align for matrix multiplication
+                    transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
+                    # Compute similarity scores using batch matrix multiplication
+                    computed_scores = torch.matmul(
+                        document_embeddings, transposed_query_embeddings
+                    )
+                    # Apply max pooling to extract the highest semantic similarity across each document's sequence
+                    maximum_scores = torch.max(computed_scores, dim=1).values
+
+                    # Sum up the maximum scores across features to get the overall document relevance scores
+                    final_scores = maximum_scores.sum(dim=1)
+
+                    normalized_scores = torch.softmax(final_scores, dim=0)
+
+                    return normalized_scores.numpy().astype(np.float32)
+
+                def predict(self, sentences):
+
+                    query = sentences[0][0]
+                    docs = [i[1] for i in sentences]
+
+                    # Embedding the documents
+                    embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
+                    # Embedding the queries
+                    embedded_queries = self.ckpt.queryFromText([query], bsize=32)
+                    embedded_query = embedded_queries[0]
+
+                    # Calculate retrieval scores for the query against all documents
+                    scores = self.calculate_similarity_scores(
+                        embedded_query.unsqueeze(0), embedded_docs
+                    )
+
+                    return scores
+
+            app.state.sentence_transformer_rf = Colbert(reranking_model)
+        else:
+            import sentence_transformers
+
+            try:
+                app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+                    get_model_path(reranking_model, update_model),
+                    device=DEVICE_TYPE,
+                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+                )
+            except:
+                log.error("CrossEncoder error")
+                app.state.sentence_transformer_rf = None
+                app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
     else:
         app.state.sentence_transformer_rf = None
 

+ 1 - 2
backend/open_webui/apps/rag/utils.py

@@ -232,8 +232,7 @@ def query_collection_with_hybrid_search(
 
     if error:
         raise Exception(
-            "Hybrid search failed for all collections. Using "
-            "Non hybrid search as fallback."
+            "Hybrid search failed for all collections. Using Non hybrid search as fallback."
         )
 
     return merge_and_sort_query_results(results, k=k, reverse=True)