# milvus.py -- 6.7 KB, 205 lines in the original listing

import json
import logging
from typing import Optional

from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType

from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    MILVUS_URI,
)
  9. class MilvusClient:
  10. def __init__(self):
  11. self.collection_prefix = "open_webui"
  12. self.client = Client(uri=MILVUS_URI)
  13. def _result_to_get_result(self, result) -> GetResult:
  14. print(result)
  15. ids = []
  16. documents = []
  17. metadatas = []
  18. for match in result:
  19. _ids = []
  20. _documents = []
  21. _metadatas = []
  22. for item in match:
  23. _ids.append(item.get("id"))
  24. _documents.append(item.get("data", {}).get("text"))
  25. _metadatas.append(item.get("metadata"))
  26. ids.append(_ids)
  27. documents.append(_documents)
  28. metadatas.append(_metadatas)
  29. return GetResult(
  30. **{
  31. "ids": ids,
  32. "documents": documents,
  33. "metadatas": metadatas,
  34. }
  35. )
  36. def _result_to_search_result(self, result) -> SearchResult:
  37. print(result)
  38. ids = []
  39. distances = []
  40. documents = []
  41. metadatas = []
  42. for match in result:
  43. _ids = []
  44. _distances = []
  45. _documents = []
  46. _metadatas = []
  47. for item in match:
  48. _ids.append(item.get("id"))
  49. _distances.append(item.get("distance"))
  50. _documents.append(item.get("entity", {}).get("data", {}).get("text"))
  51. _metadatas.append(item.get("entity", {}).get("metadata"))
  52. ids.append(_ids)
  53. distances.append(_distances)
  54. documents.append(_documents)
  55. metadatas.append(_metadatas)
  56. return SearchResult(
  57. **{
  58. "ids": ids,
  59. "distances": distances,
  60. "documents": documents,
  61. "metadatas": metadatas,
  62. }
  63. )
  64. def _create_collection(self, collection_name: str, dimension: int):
  65. schema = self.client.create_schema(
  66. auto_id=False,
  67. enable_dynamic_field=True,
  68. )
  69. schema.add_field(
  70. field_name="id",
  71. datatype=DataType.VARCHAR,
  72. is_primary=True,
  73. max_length=65535,
  74. )
  75. schema.add_field(
  76. field_name="vector",
  77. datatype=DataType.FLOAT_VECTOR,
  78. dim=dimension,
  79. description="vector",
  80. )
  81. schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
  82. schema.add_field(
  83. field_name="metadata", datatype=DataType.JSON, description="metadata"
  84. )
  85. index_params = self.client.prepare_index_params()
  86. index_params.add_index(
  87. field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
  88. )
  89. self.client.create_collection(
  90. collection_name=f"{self.collection_prefix}_{collection_name}",
  91. schema=schema,
  92. index_params=index_params,
  93. )
  94. def has_collection(self, collection_name: str) -> bool:
  95. # Check if the collection exists based on the collection name.
  96. return self.client.has_collection(
  97. collection_name=f"{self.collection_prefix}_{collection_name}"
  98. )
  99. def delete_collection(self, collection_name: str):
  100. # Delete the collection based on the collection name.
  101. return self.client.drop_collection(
  102. collection_name=f"{self.collection_prefix}_{collection_name}"
  103. )
  104. def search(
  105. self, collection_name: str, vectors: list[list[float | int]], limit: int
  106. ) -> Optional[SearchResult]:
  107. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  108. result = self.client.search(
  109. collection_name=f"{self.collection_prefix}_{collection_name}",
  110. data=vectors,
  111. limit=limit,
  112. output_fields=["data", "metadata"],
  113. )
  114. return self._result_to_search_result(result)
  115. def get(self, collection_name: str) -> Optional[GetResult]:
  116. # Get all the items in the collection.
  117. result = self.client.query(
  118. collection_name=f"{self.collection_prefix}_{collection_name}",
  119. filter='id != ""',
  120. )
  121. return self._result_to_get_result([result])
  122. def insert(self, collection_name: str, items: list[VectorItem]):
  123. # Insert the items into the collection, if the collection does not exist, it will be created.
  124. if not self.client.has_collection(
  125. collection_name=f"{self.collection_prefix}_{collection_name}"
  126. ):
  127. self._create_collection(
  128. collection_name=collection_name, dimension=len(items[0]["vector"])
  129. )
  130. return self.client.insert(
  131. collection_name=f"{self.collection_prefix}_{collection_name}",
  132. data=[
  133. {
  134. "id": item["id"],
  135. "vector": item["vector"],
  136. "data": {"text": item["text"]},
  137. "metadata": item["metadata"],
  138. }
  139. for item in items
  140. ],
  141. )
  142. def upsert(self, collection_name: str, items: list[VectorItem]):
  143. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  144. if not self.client.has_collection(
  145. collection_name=f"{self.collection_prefix}_{collection_name}"
  146. ):
  147. self._create_collection(
  148. collection_name=collection_name, dimension=len(items[0]["vector"])
  149. )
  150. return self.client.upsert(
  151. collection_name=f"{self.collection_prefix}_{collection_name}",
  152. data=[
  153. {
  154. "id": item["id"],
  155. "vector": item["vector"],
  156. "data": {"text": item["text"]},
  157. "metadata": item["metadata"],
  158. }
  159. for item in items
  160. ],
  161. )
  162. def delete(self, collection_name: str, ids: list[str]):
  163. # Delete the items from the collection based on the ids.
  164. return self.client.delete(
  165. collection_name=f"{self.collection_prefix}_{collection_name}",
  166. ids=ids,
  167. )
  168. def reset(self):
  169. # Resets the database. This will delete all collections and item entries.
  170. collection_names = self.client.list_collections()
  171. for collection_name in collection_names:
  172. if collection_name.startswith(self.collection_prefix):
  173. self.client.drop_collection(collection_name=collection_name)