milvus.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. from pymilvus import MilvusClient as Client
  2. from pymilvus import FieldSchema, DataType
  3. import json
  4. from typing import Optional
  5. from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
  6. from open_webui.config import (
  7. MILVUS_URI,
  8. )
  9. class MilvusClient:
  10. def __init__(self):
  11. self.collection_prefix = "open_webui"
  12. self.client = Client(uri=MILVUS_URI)
  13. def _result_to_query_result(self, result) -> SearchResult:
  14. print(result)
  15. ids = []
  16. distances = []
  17. documents = []
  18. metadatas = []
  19. for match in result:
  20. _ids = []
  21. _distances = []
  22. _documents = []
  23. _metadatas = []
  24. for item in match:
  25. _ids.append(item.get("id"))
  26. _distances.append(item.get("distance"))
  27. _documents.append(item.get("entity", {}).get("data", {}).get("text"))
  28. _metadatas.append(item.get("entity", {}).get("metadata"))
  29. ids.append(_ids)
  30. distances.append(_distances)
  31. documents.append(_documents)
  32. metadatas.append(_metadatas)
  33. return SearchResult(
  34. **{
  35. "ids": ids,
  36. "distances": distances,
  37. "documents": documents,
  38. "metadatas": metadatas,
  39. }
  40. )
  41. def _create_collection(self, collection_name: str, dimension: int):
  42. schema = self.client.create_schema(
  43. auto_id=False,
  44. enable_dynamic_field=True,
  45. )
  46. schema.add_field(
  47. field_name="id",
  48. datatype=DataType.VARCHAR,
  49. is_primary=True,
  50. max_length=65535,
  51. )
  52. schema.add_field(
  53. field_name="vector",
  54. datatype=DataType.FLOAT_VECTOR,
  55. dim=dimension,
  56. description="vector",
  57. )
  58. schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
  59. schema.add_field(
  60. field_name="metadata", datatype=DataType.JSON, description="metadata"
  61. )
  62. index_params = self.client.prepare_index_params()
  63. index_params.add_index(
  64. field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
  65. )
  66. self.client.create_collection(
  67. collection_name=f"{self.collection_prefix}_{collection_name}",
  68. schema=schema,
  69. index_params=index_params,
  70. )
  71. def has_collection(self, collection_name: str) -> bool:
  72. # Check if the collection exists based on the collection name.
  73. return self.client.has_collection(
  74. collection_name=f"{self.collection_prefix}_{collection_name}"
  75. )
  76. def delete_collection(self, collection_name: str):
  77. # Delete the collection based on the collection name.
  78. return self.client.drop_collection(
  79. collection_name=f"{self.collection_prefix}_{collection_name}"
  80. )
  81. def search(
  82. self, collection_name: str, vectors: list[list[float | int]], limit: int
  83. ) -> Optional[SearchResult]:
  84. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  85. result = self.client.search(
  86. collection_name=f"{self.collection_prefix}_{collection_name}",
  87. data=vectors,
  88. limit=limit,
  89. output_fields=["data", "metadata"],
  90. )
  91. return self._result_to_query_result(result)
  92. def get(self, collection_name: str) -> Optional[GetResult]:
  93. # Get all the items in the collection.
  94. result = self.client.query(
  95. collection_name=f"{self.collection_prefix}_{collection_name}",
  96. filter='id != ""',
  97. )
  98. return self._result_to_query_result(result)
  99. def insert(self, collection_name: str, items: list[VectorItem]):
  100. # Insert the items into the collection, if the collection does not exist, it will be created.
  101. if not self.client.has_collection(
  102. collection_name=f"{self.collection_prefix}_{collection_name}"
  103. ):
  104. self._create_collection(
  105. collection_name=collection_name, dimension=len(items[0]["vector"])
  106. )
  107. return self.client.insert(
  108. collection_name=f"{self.collection_prefix}_{collection_name}",
  109. data=[
  110. {
  111. "id": item["id"],
  112. "vector": item["vector"],
  113. "data": {"text": item["text"]},
  114. "metadata": item["metadata"],
  115. }
  116. for item in items
  117. ],
  118. )
  119. def upsert(self, collection_name: str, items: list[VectorItem]):
  120. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  121. if not self.client.has_collection(
  122. collection_name=f"{self.collection_prefix}_{collection_name}"
  123. ):
  124. self._create_collection(
  125. collection_name=collection_name, dimension=len(items[0]["vector"])
  126. )
  127. return self.client.upsert(
  128. collection_name=f"{self.collection_prefix}_{collection_name}",
  129. data=[
  130. {
  131. "id": item["id"],
  132. "vector": item["vector"],
  133. "data": {"text": item["text"]},
  134. "metadata": item["metadata"],
  135. }
  136. for item in items
  137. ],
  138. )
  139. def delete(self, collection_name: str, ids: list[str]):
  140. # Delete the items from the collection based on the ids.
  141. return self.client.delete(
  142. collection_name=f"{self.collection_prefix}_{collection_name}",
  143. ids=ids,
  144. )
  145. def reset(self):
  146. # Resets the database. This will delete all collections and item entries.
  147. collection_names = self.client.list_collections()
  148. for collection_name in collection_names:
  149. if collection_name.startswith(self.collection_prefix):
  150. self.client.drop_collection(collection_name=collection_name)