colbert.py

import os

import numpy as np
import torch

from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint

class ColBERT:
    def __init__(self, name, **kwargs) -> None:
        print("ColBERT: Loading model", name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        if kwargs.get("env") == "docker":
            # Workaround for an issue inside the Docker container where the
            # torch extension is not built properly and the following error
            # is thrown:
            #   /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so:
            #   cannot open shared object file: No such file or directory
            # Deleting the stale lock file allows the extension to be rebuilt.
            lock_file = (
                "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
            )
            if os.path.exists(lock_file):
                os.remove(lock_file)

        self.ckpt = Checkpoint(
            name,
            colbert_config=ColBERTConfig(model_name=name),
        ).to(self.device)

    def calculate_similarity_scores(self, query_embeddings, document_embeddings):
        query_embeddings = query_embeddings.to(self.device)
        document_embeddings = document_embeddings.to(self.device)

        # Validate dimensions to ensure compatibility
        if query_embeddings.dim() != 3:
            raise ValueError(
                f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
            )
        if document_embeddings.dim() != 3:
            raise ValueError(
                f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
            )
        if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
            raise ValueError(
                "There should be either one query or as many queries as there are documents."
            )

        # Transpose the query embeddings to align them for matrix multiplication
        transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
        # Token-to-token similarities via batched matrix multiplication:
        # one score for every (document token, query token) pair
        computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)
        # MaxSim: for each query token, keep its best-matching document token
        maximum_scores = torch.max(computed_scores, dim=1).values
        # Sum the per-query-token maxima to get each document's relevance score
        final_scores = maximum_scores.sum(dim=1)
        # Softmax turns the raw scores into a distribution over the candidate documents
        normalized_scores = torch.softmax(final_scores, dim=0)
        return normalized_scores.detach().cpu().numpy().astype(np.float32)
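
    # Shape walk-through for calculate_similarity_scores (illustrative only;
    # the token counts and embedding dim below are assumed, not fixed):
    #   query_embeddings:    (1, 32, 128)   one query, 32 tokens, 128-dim vectors
    #   document_embeddings: (8, 180, 128)  eight docs, 180 tokens each
    #   computed_scores:     (8, 180, 32)   doc-token x query-token similarities
    #   maximum_scores:      (8, 32)        best doc token per query token (MaxSim)
    #   final_scores:        (8,)           one relevance score per document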

    def predict(self, sentences):
        # `sentences` is a list of (query, document) pairs sharing a single query
        query = sentences[0][0]
        docs = [pair[1] for pair in sentences]
        # Embed the documents (the embedding tensor is the first element of
        # what docFromText returns here, hence the [0])
        embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
        # Embed the single query
        embedded_queries = self.ckpt.queryFromText([query], bsize=32)
        embedded_query = embedded_queries[0]
        # Calculate retrieval scores for the query against all documents
        scores = self.calculate_similarity_scores(
            embedded_query.unsqueeze(0), embedded_docs
        )
        return scores
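

# --- Usage sketch (not part of the original file) ---
# A minimal example of driving this reranker, assuming the public
# "colbert-ir/colbertv2.0" checkpoint; the model name and sample texts
# are illustrative, and scores come back softmax-normalized over the docs.
if __name__ == "__main__":
    reranker = ColBERT("colbert-ir/colbertv2.0")
    query = "what is semantic search?"
    docs = [
        "Semantic search ranks results by meaning rather than exact keywords.",
        "A sourdough recipe with a long overnight fermentation.",
    ]
    scores = reranker.predict([(query, doc) for doc in docs])
    for doc, score in sorted(zip(docs, scores), key=lambda pair: -pair[1]):
        print(f"{score:.4f}  {doc}")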