فهرست منبع

enh: vector db hash collision check

Timothy J. Baek 7 ماه پیش
والد
کامیت
2fc07fd6a2

+ 10 - 0
backend/open_webui/apps/retrieval/main.py

@@ -641,6 +641,16 @@ def save_docs_to_vector_db(
 ) -> bool:
 ) -> bool:
     log.info(f"save_docs_to_vector_db {docs} {collection_name}")
     log.info(f"save_docs_to_vector_db {docs} {collection_name}")
 
 
+    # Check if entries with the same hash (metadata.hash) already exist
+    if metadata and "hash" in metadata:
+        existing_docs = VECTOR_DB_CLIENT.query(
+            collection_name=collection_name,
+            filter={"hash": metadata["hash"]},
+        )
+        if existing_docs:
+            log.info(f"Document with hash {metadata['hash']} already exists")
+            return True
+
     if split:
     if split:
         text_splitter = RecursiveCharacterTextSplitter(
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=app.state.config.CHUNK_SIZE,
             chunk_size=app.state.config.CHUNK_SIZE,

+ 21 - 0
backend/open_webui/apps/retrieval/vector/dbs/chroma.py

@@ -66,6 +66,27 @@ class ChromaClient:
             )
             )
         return None
         return None
 
 
+    def query(
+        self, collection_name: str, filter: dict, limit: int = 1
+    ) -> Optional[SearchResult]:
+        # Query the items from the collection based on the filter.
+        collection = self.client.get_collection(name=collection_name)
+        if collection:
+            result = collection.query(
+                where=filter,
+                n_results=limit,
+            )
+
+            return SearchResult(
+                **{
+                    "ids": result["ids"],
+                    "distances": result["distances"],
+                    "documents": result["documents"],
+                    "metadatas": result["metadatas"],
+                }
+            )
+        return None
+
     def get(self, collection_name: str) -> Optional[GetResult]:
     def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.
         # Get all the items in the collection.
         collection = self.client.get_collection(name=collection_name)
         collection = self.client.get_collection(name=collection_name)

+ 19 - 0
backend/open_webui/apps/retrieval/vector/dbs/milvus.py

@@ -135,6 +135,25 @@ class MilvusClient:
 
 
         return self._result_to_search_result(result)
         return self._result_to_search_result(result)
 
 
+    def query(
+        self, collection_name: str, filter: dict, limit: int = 1
+    ) -> Optional[SearchResult]:
+        # Query the items from the collection based on the filter.
+        filter_string = " && ".join(
+            [
+                f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
+                for key, value in filter.items()
+            ]
+        )
+
+        result = self.client.query(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            filter=filter_string,
+            limit=limit,
+        )
+
+        return self._result_to_search_result([result])
+
     def get(self, collection_name: str) -> Optional[GetResult]:
     def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.
         # Get all the items in the collection.
         result = self.client.query(
         result = self.client.query(