Timothy J. Baek преди 6 месеца
родител
ревизия
05c15b017d
променени са 1 файла, в които са добавени 48 реда и са изтрити 34 реда
  1. 48 34
      backend/open_webui/apps/retrieval/vector/dbs/milvus.py

+ 48 - 34
backend/open_webui/apps/retrieval/vector/dbs/milvus.py

@@ -24,7 +24,6 @@ class MilvusClient:
             _ids = []
             _ids = []
             _documents = []
             _documents = []
             _metadatas = []
             _metadatas = []
-
             for item in match:
             for item in match:
                 _ids.append(item.get("id"))
                 _ids.append(item.get("id"))
                 _documents.append(item.get("data", {}).get("text"))
                 _documents.append(item.get("data", {}).get("text"))
@@ -112,12 +111,14 @@ class MilvusClient:
 
 
     def has_collection(self, collection_name: str) -> bool:
     def has_collection(self, collection_name: str) -> bool:
         # Check if the collection exists based on the collection name.
         # Check if the collection exists based on the collection name.
+        collection_name = collection_name.replace("-", "_")
         return self.client.has_collection(
         return self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
             collection_name=f"{self.collection_prefix}_{collection_name}"
         )
         )
 
 
     def delete_collection(self, collection_name: str):
     def delete_collection(self, collection_name: str):
         # Delete the collection based on the collection name.
         # Delete the collection based on the collection name.
+        collection_name = collection_name.replace("-", "_")
         return self.client.drop_collection(
         return self.client.drop_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
             collection_name=f"{self.collection_prefix}_{collection_name}"
         )
         )
@@ -126,6 +127,7 @@ class MilvusClient:
         self, collection_name: str, vectors: list[list[float | int]], limit: int
         self, collection_name: str, vectors: list[list[float | int]], limit: int
     ) -> Optional[SearchResult]:
     ) -> Optional[SearchResult]:
         # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
         # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
+        collection_name = collection_name.replace("-", "_")
         result = self.client.search(
         result = self.client.search(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=vectors,
             data=vectors,
@@ -137,9 +139,13 @@ class MilvusClient:
 
 
     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
         # Construct the filter string for querying
         # Construct the filter string for querying
+        collection_name = collection_name.replace("-", "_")
+        if not self.has_collection(collection_name):
+            return None
+
         filter_string = " && ".join(
         filter_string = " && ".join(
             [
             [
-                f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
+                f'metadata["{key}"] == {json.dumps(value)}'
                 for key, value in filter.items()
                 for key, value in filter.items()
             ]
             ]
         )
         )
@@ -154,38 +160,45 @@ class MilvusClient:
         offset = 0
         offset = 0
         remaining = limit
         remaining = limit
 
 
-        # Loop until there are no more items to fetch or the desired limit is reached
-        while remaining > 0:
-            current_fetch = min(
-                max_limit, remaining
-            )  # Determine how many items to fetch in this iteration
-
-            results = self.client.query(
-                collection_name=f"{self.collection_prefix}_{collection_name}",
-                filter=filter_string,
-                output_fields=["*"],
-                limit=current_fetch,
-                offset=offset,
-            )
-
-            if not results:
-                break
-
-            all_results.extend(results)
-            results_count = len(results)
-            remaining -= (
-                results_count  # Decrease remaining by the number of items fetched
-            )
-            offset += results_count
-
-            # Break the loop if the results returned are less than the requested fetch count
-            if results_count < current_fetch:
-                break
-
-        return self._result_to_get_result(all_results)
+        try:
+            # Loop until there are no more items to fetch or the desired limit is reached
+            while remaining > 0:
+                print("remaining", remaining)
+                current_fetch = min(
+                    max_limit, remaining
+                )  # Determine how many items to fetch in this iteration
+
+                results = self.client.query(
+                    collection_name=f"{self.collection_prefix}_{collection_name}",
+                    filter=filter_string,
+                    output_fields=["*"],
+                    limit=current_fetch,
+                    offset=offset,
+                )
+
+                if not results:
+                    break
+
+                all_results.extend(results)
+                results_count = len(results)
+                remaining -= (
+                    results_count  # Decrease remaining by the number of items fetched
+                )
+                offset += results_count
+
+                # Break the loop if the results returned are less than the requested fetch count
+                if results_count < current_fetch:
+                    break
+
+            print(all_results)
+            return self._result_to_get_result([all_results])
+        except Exception as e:
+            print(e)
+            return None
 
 
     def get(self, collection_name: str) -> Optional[GetResult]:
     def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.
         # Get all the items in the collection.
+        collection_name = collection_name.replace("-", "_")
         result = self.client.query(
         result = self.client.query(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             collection_name=f"{self.collection_prefix}_{collection_name}",
             filter='id != ""',
             filter='id != ""',
@@ -194,6 +207,7 @@ class MilvusClient:
 
 
     def insert(self, collection_name: str, items: list[VectorItem]):
     def insert(self, collection_name: str, items: list[VectorItem]):
         # Insert the items into the collection, if the collection does not exist, it will be created.
         # Insert the items into the collection, if the collection does not exist, it will be created.
+        collection_name = collection_name.replace("-", "_")
         if not self.client.has_collection(
         if not self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
             collection_name=f"{self.collection_prefix}_{collection_name}"
         ):
         ):
@@ -216,6 +230,7 @@ class MilvusClient:
 
 
     def upsert(self, collection_name: str, items: list[VectorItem]):
     def upsert(self, collection_name: str, items: list[VectorItem]):
         # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
         # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
+        collection_name = collection_name.replace("-", "_")
         if not self.client.has_collection(
         if not self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
             collection_name=f"{self.collection_prefix}_{collection_name}"
         ):
         ):
@@ -243,7 +258,7 @@ class MilvusClient:
         filter: Optional[dict] = None,
         filter: Optional[dict] = None,
     ):
     ):
         # Delete the items from the collection based on the ids.
         # Delete the items from the collection based on the ids.
-
+        collection_name = collection_name.replace("-", "_")
         if ids:
         if ids:
             return self.client.delete(
             return self.client.delete(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
                 collection_name=f"{self.collection_prefix}_{collection_name}",
@@ -253,7 +268,7 @@ class MilvusClient:
             # Convert the filter dictionary to a string using JSON_CONTAINS.
             # Convert the filter dictionary to a string using JSON_CONTAINS.
             filter_string = " && ".join(
             filter_string = " && ".join(
                 [
                 [
-                    f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
+                    f'metadata["{key}"] == {json.dumps(value)}'
                     for key, value in filter.items()
                     for key, value in filter.items()
                 ]
                 ]
             )
             )
@@ -265,7 +280,6 @@ class MilvusClient:
 
 
     def reset(self):
     def reset(self):
         # Resets the database. This will delete all collections and item entries.
         # Resets the database. This will delete all collections and item entries.
-
         collection_names = self.client.list_collections()
         collection_names = self.client.list_collections()
         for collection_name in collection_names:
         for collection_name in collection_names:
             if collection_name.startswith(self.collection_prefix):
             if collection_name.startswith(self.collection_prefix):