|
@@ -203,12 +203,19 @@ def update_reranking_model(
|
|
|
|
|
|
class Colbert:
|
|
|
def __init__(self, name) -> None:
|
|
|
- self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
|
|
|
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
+ self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to(
|
|
|
+ self.device
|
|
|
+ )
|
|
|
pass
|
|
|
|
|
|
def calculate_similarity_scores(
|
|
|
self, query_embeddings, document_embeddings
|
|
|
):
|
|
|
+
|
|
|
+ query_embeddings = query_embeddings.to(self.device)
|
|
|
+ document_embeddings = document_embeddings.to(self.device)
|
|
|
+
|
|
|
# Validate dimensions to ensure compatibility
|
|
|
if query_embeddings.dim() != 3:
|
|
|
raise ValueError(
|
|
@@ -237,7 +244,7 @@ def update_reranking_model(
|
|
|
|
|
|
normalized_scores = torch.softmax(final_scores, dim=0)
|
|
|
|
|
|
- return normalized_scores.numpy().astype(np.float32)
|
|
|
+ return normalized_scores.detach().cpu().numpy().astype(np.float32)
|
|
|
|
|
|
def predict(self, sentences):
|
|
|
|