import os

import numpy as np
import torch

from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint


class ColBERT:
    def __init__(self, name, **kwargs) -> None:
        print("ColBERT: Loading model", name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        DOCKER = kwargs.get("env") == "docker"
        if DOCKER:
            # Workaround for running inside the Docker container: a stale lock
            # file can prevent the torch extension from being (re)built, which
            # otherwise fails with:
            # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
            lock_file = (
                "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
            )
            if os.path.exists(lock_file):
                os.remove(lock_file)

        self.ckpt = Checkpoint(
            name,
            colbert_config=ColBERTConfig(model_name=name),
        ).to(self.device)

    def calculate_similarity_scores(self, query_embeddings, document_embeddings):
        query_embeddings = query_embeddings.to(self.device)
        document_embeddings = document_embeddings.to(self.device)

        # Validate dimensions: both inputs must be batched token-level
        # embeddings of shape (batch, sequence_length, embedding_dim).
        if query_embeddings.dim() != 3:
            raise ValueError(
                f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
            )
        if document_embeddings.dim() != 3:
            raise ValueError(
                f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
            )
        if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
            raise ValueError(
                "There must be either a single query or exactly one query per document."
            )

        # Transpose the query embeddings to align them for batched matrix
        # multiplication: (batch, seq_len, dim) -> (batch, dim, seq_len).
        transposed_query_embeddings = query_embeddings.permute(0, 2, 1)

        # Token-level similarity matrix: one score for every
        # (document token, query token) pair.
        computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)

        # MaxSim: for each query token, keep its best-matching document token.
        maximum_scores = torch.max(computed_scores, dim=1).values

        # Sum the per-query-token maxima to get one relevance score per document.
        final_scores = maximum_scores.sum(dim=1)

        # Normalize across the candidate documents so the scores sum to 1.
        normalized_scores = torch.softmax(final_scores, dim=0)

        return normalized_scores.detach().cpu().numpy().astype(np.float32)
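
    # Shape walkthrough with hypothetical sizes (not from the original
    # listing): a single 32-token query, 10 documents padded to 180 tokens,
    # and 128-dim ColBERT embeddings would flow through the method as:
    #   query_embeddings            -> (1, 32, 128)
    #   transposed_query_embeddings -> (1, 128, 32)
    #   computed_scores             -> (10, 180, 32)  via batch broadcasting
    #   maximum_scores              -> (10, 32)       best doc token per query token
    #   final_scores                -> (10,)          one relevance score per document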

    def predict(self, sentences):
        # `sentences` is a list of (query, document) pairs that all share the
        # same query, mirroring the sentence-transformers CrossEncoder input.
        query = sentences[0][0]
        docs = [pair[1] for pair in sentences]

        # Embed the documents; [0] extracts the embeddings tensor from the
        # return value of docFromText.
        embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]

        # Embed the query and select the single query's token embeddings.
        embedded_queries = self.ckpt.queryFromText([query], bsize=32)
        embedded_query = embedded_queries[0]

        # Score the query against all candidate documents.
        scores = self.calculate_similarity_scores(
            embedded_query.unsqueeze(0), embedded_docs
        )

        return scores
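

# Minimal usage sketch (not part of the original listing). It assumes the
# colbert-ai package is installed and uses the public checkpoint name
# "colbert-ir/colbertv2.0" as a plausible example; the query and documents
# below are made up for illustration.
if __name__ == "__main__":
    model = ColBERT("colbert-ir/colbertv2.0")

    query = "What is vector search?"
    docs = [
        "Vector search retrieves items by embedding similarity.",
        "The weather in Paris is mild in spring.",
    ]

    scores = model.predict([(query, doc) for doc in docs])

    # Scores are softmax-normalized over the candidates, so they sum to 1.0;
    # higher means more relevant.
    for doc, score in sorted(zip(docs, scores), key=lambda pair: -pair[1]):
        print(f"{score:.4f}  {doc}")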