浏览代码

refac: colbert cuda support

Timothy J. Baek 7 月之前
父节点
当前提交
06debb322b
共有 1 个文件被更改,包括 9 次插入2 次删除
  1. 9 2
      backend/open_webui/apps/rag/main.py

+ 9 - 2
backend/open_webui/apps/rag/main.py

@@ -203,12 +203,19 @@ def update_reranking_model(
 
 
             class Colbert:
             class Colbert:
                 def __init__(self, name) -> None:
                 def __init__(self, name) -> None:
-                    self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
+                    self.device = "cuda" if torch.cuda.is_available() else "cpu"
+                    self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to(
+                        self.device
+                    )
                     pass
                     pass
 
 
                 def calculate_similarity_scores(
                 def calculate_similarity_scores(
                     self, query_embeddings, document_embeddings
                     self, query_embeddings, document_embeddings
                 ):
                 ):
+
+                    query_embeddings = query_embeddings.to(self.device)
+                    document_embeddings = document_embeddings.to(self.device)
+
                     # Validate dimensions to ensure compatibility
                     # Validate dimensions to ensure compatibility
                     if query_embeddings.dim() != 3:
                     if query_embeddings.dim() != 3:
                         raise ValueError(
                         raise ValueError(
@@ -237,7 +244,7 @@ def update_reranking_model(
 
 
                     normalized_scores = torch.softmax(final_scores, dim=0)
                     normalized_scores = torch.softmax(final_scores, dim=0)
 
 
-                    return normalized_scores.numpy().astype(np.float32)
+                    return normalized_scores.detach().cpu().numpy().astype(np.float32)
 
 
                 def predict(self, sentences):
                 def predict(self, sentences):