Timothy J. Baek 7 месяцев назад
Родитель
Коммит
06debb322b
1 изменённый файл: 9 добавлений и 2 удаления
  1. 9 2
      backend/open_webui/apps/rag/main.py

+ 9 - 2
backend/open_webui/apps/rag/main.py

@@ -203,12 +203,19 @@ def update_reranking_model(
 
             class Colbert:
                 def __init__(self, name) -> None:
-                    self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
+                    self.device = "cuda" if torch.cuda.is_available() else "cpu"
+                    self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to(
+                        self.device
+                    )
                     pass
 
                 def calculate_similarity_scores(
                     self, query_embeddings, document_embeddings
                 ):
+
+                    query_embeddings = query_embeddings.to(self.device)
+                    document_embeddings = document_embeddings.to(self.device)
+
                     # Validate dimensions to ensure compatibility
                     if query_embeddings.dim() != 3:
                         raise ValueError(
@@ -237,7 +244,7 @@ def update_reranking_model(
 
                     normalized_scores = torch.softmax(final_scores, dim=0)
 
-                    return normalized_scores.numpy().astype(np.float32)
+                    return normalized_scores.detach().cpu().numpy().astype(np.float32)
 
                 def predict(self, sentences):