@@ -10,6 +10,9 @@ from datetime import datetime
 from pathlib import Path
 from typing import Iterator, Optional, Sequence, Union
 
+
+import numpy as np
+import torch
 import requests
 import validators
 
@@ -114,6 +117,8 @@ from langchain_community.document_loaders import (
     YoutubeLoader,
 )
 from langchain_core.documents import Document
+from colbert.infra import ColBERTConfig
+from colbert.modeling.checkpoint import Checkpoint
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -193,18 +198,76 @@ def update_reranking_model(
     update_model: bool = False,
 ):
     if reranking_model:
-        import sentence_transformers
+        if reranking_model in ["jinaai/jina-colbert-v2"]:
 
-        try:
-            app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-                get_model_path(reranking_model, update_model),
-                device=DEVICE_TYPE,
-                trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-            )
-        except:
-            log.error("CrossEncoder error")
-            app.state.sentence_transformer_rf = None
-            app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
+            class Colbert:
+                def __init__(self, name) -> None:
+                    self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
+
+                def calculate_similarity_scores(self, query_embeddings, document_embeddings):
+                    # Validate dimensions to ensure compatibility
+                    if query_embeddings.dim() != 3:
+                        raise ValueError(
+                            f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
+                        )
+                    if document_embeddings.dim() != 3:
+                        raise ValueError(
+                            f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
+                        )
+                    if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
+                        raise ValueError(
+                            "There should be either one query or as many queries as there are documents."
+                        )
+
+                    # Transpose the query embeddings to align them for matrix multiplication
+                    transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
+                    # Compute token-level similarity scores using batch matrix multiplication
+                    computed_scores = torch.matmul(
+                        document_embeddings, transposed_query_embeddings
+                    )
+                    # Max-pool over document tokens: keep each query token's best match
+                    maximum_scores = torch.max(computed_scores, dim=1).values
+
+                    # Sum the maximum scores across query tokens to get one relevance score per document
+                    final_scores = maximum_scores.sum(dim=1)
+
+                    normalized_scores = torch.softmax(final_scores, dim=0)
+
+                    # Move to CPU before converting, in case the model ran on GPU
+                    return normalized_scores.cpu().numpy().astype(np.float32)
+
+                def predict(self, sentences):
+                    # sentences holds CrossEncoder-style [query, document] pairs
+                    query = sentences[0][0]
+                    docs = [i[1] for i in sentences]
+
+                    # Embed the documents
+                    embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
+                    # Embed the query
+                    embedded_queries = self.ckpt.queryFromText([query], bsize=32)
+                    embedded_query = embedded_queries[0]
+
+                    # Score the query against all documents
+                    scores = self.calculate_similarity_scores(
+                        embedded_query.unsqueeze(0), embedded_docs
+                    )
+
+                    return scores
+
+            app.state.sentence_transformer_rf = Colbert(reranking_model)
+        else:
+            import sentence_transformers
+
+            try:
+                app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+                    get_model_path(reranking_model, update_model),
+                    device=DEVICE_TYPE,
+                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+                )
+            except Exception:
+                log.error("CrossEncoder error")
+                app.state.sentence_transformer_rf = None
+                app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
     else:
         app.state.sentence_transformer_rf = None
 
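
For reference, the scoring above is ColBERT-style late interaction (MaxSim): every document token is compared against every query token, each query token keeps its best-matching document token, and the per-document score is the sum of those maxima. A minimal self-contained sketch with random stand-in tensors (shapes and values are illustrative only, not real model output):

import torch

torch.manual_seed(0)
query_embeddings = torch.randn(1, 4, 8)     # (n_queries, query_tokens, dim)
document_embeddings = torch.randn(3, 6, 8)  # (n_docs, doc_tokens, dim)

# Token-level similarities; broadcasting yields (n_docs, doc_tokens, query_tokens)
token_scores = torch.matmul(document_embeddings, query_embeddings.permute(0, 2, 1))

# MaxSim: each query token keeps its best-matching document token...
best_per_query_token = token_scores.max(dim=1).values  # (n_docs, query_tokens)

# ...and each document's score is the sum over query tokens
doc_scores = best_per_query_token.sum(dim=1)           # (n_docs,)
print(torch.softmax(doc_scores, dim=0))                # normalized, as in the patch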
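
And a usage sketch for the new reranker. In the patch the class is local to update_reranking_model and is reached via app.state.sentence_transformer_rf; it is shown standalone here purely for illustration, assuming the jinaai/jina-colbert-v2 checkpoint is available locally. The input format mirrors sentence_transformers.CrossEncoder.predict: one [query, document] pair per candidate.

reranker = Colbert("jinaai/jina-colbert-v2")

query = "how does late interaction scoring work?"
docs = [
    "ColBERT compares every query token against every document token.",
    "Bananas are a good source of potassium.",
]

# One [query, doc] pair per document, as CrossEncoder.predict expects
scores = reranker.predict([[query, doc] for doc in docs])
print(scores)  # one softmax-normalized float32 score per document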