milvus.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from pymilvus import MilvusClient as Client
  2. from pymilvus import FieldSchema, DataType
  3. import json
  4. from typing import Optional
  5. from open_webui.apps.rag.vector.main import VectorItem, QueryResult
  6. from open_webui.config import (
  7. MILVUS_URI,
  8. )
  9. class MilvusClient:
  10. def __init__(self):
  11. self.collection_prefix = "open_webui"
  12. self.client = Client(uri=MILVUS_URI)
  13. def _result_to_query_result(self, result) -> QueryResult:
  14. print(result)
  15. ids = []
  16. distances = []
  17. documents = []
  18. metadatas = []
  19. for match in result:
  20. _ids = []
  21. _distances = []
  22. _documents = []
  23. _metadatas = []
  24. for item in match:
  25. _ids.append(item.get("id"))
  26. _distances.append(item.get("distance"))
  27. _documents.append(item.get("entity", {}).get("data", {}).get("text"))
  28. _metadatas.append(item.get("entity", {}).get("metadata"))
  29. ids.append(_ids)
  30. distances.append(_distances)
  31. documents.append(_documents)
  32. metadatas.append(_metadatas)
  33. return {
  34. "ids": ids,
  35. "distances": distances,
  36. "documents": documents,
  37. "metadatas": metadatas,
  38. }
  39. def _create_collection(self, collection_name: str, dimension: int):
  40. schema = self.client.create_schema(
  41. auto_id=False,
  42. enable_dynamic_field=True,
  43. )
  44. schema.add_field(
  45. field_name="id",
  46. datatype=DataType.VARCHAR,
  47. is_primary=True,
  48. max_length=65535,
  49. )
  50. schema.add_field(
  51. field_name="vector",
  52. datatype=DataType.FLOAT_VECTOR,
  53. dim=dimension,
  54. description="vector",
  55. )
  56. schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
  57. schema.add_field(
  58. field_name="metadata", datatype=DataType.JSON, description="metadata"
  59. )
  60. index_params = self.client.prepare_index_params()
  61. index_params.add_index(
  62. field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
  63. )
  64. self.client.create_collection(
  65. collection_name=f"{self.collection_prefix}_{collection_name}",
  66. schema=schema,
  67. index_params=index_params,
  68. )
  69. def has_collection(self, collection_name: str) -> bool:
  70. # Check if the collection exists based on the collection name.
  71. return self.client.has_collection(
  72. collection_name=f"{self.collection_prefix}_{collection_name}"
  73. )
  74. def delete_collection(self, collection_name: str):
  75. # Delete the collection based on the collection name.
  76. return self.client.drop_collection(
  77. collection_name=f"{self.collection_prefix}_{collection_name}"
  78. )
  79. def search(
  80. self, collection_name: str, vectors: list[list[float | int]], limit: int
  81. ) -> Optional[QueryResult]:
  82. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  83. result = self.client.search(
  84. collection_name=f"{self.collection_prefix}_{collection_name}",
  85. data=vectors,
  86. limit=limit,
  87. output_fields=["data", "metadata"],
  88. )
  89. return self._result_to_query_result(result)
  90. def get(self, collection_name: str) -> Optional[QueryResult]:
  91. # Get all the items in the collection.
  92. result = self.client.query(
  93. collection_name=f"{self.collection_prefix}_{collection_name}",
  94. )
  95. return self._result_to_query_result(result)
  96. def insert(self, collection_name: str, items: list[VectorItem]):
  97. # Insert the items into the collection, if the collection does not exist, it will be created.
  98. if not self.client.has_collection(
  99. collection_name=f"{self.collection_prefix}_{collection_name}"
  100. ):
  101. self._create_collection(
  102. collection_name=collection_name, dimension=len(items[0]["vector"])
  103. )
  104. return self.client.insert(
  105. collection_name=f"{self.collection_prefix}_{collection_name}",
  106. data=[
  107. {
  108. "id": item["id"],
  109. "vector": item["vector"],
  110. "data": {"text": item["text"]},
  111. "metadata": item["metadata"],
  112. }
  113. for item in items
  114. ],
  115. )
  116. def upsert(self, collection_name: str, items: list[VectorItem]):
  117. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  118. if not self.client.has_collection(
  119. collection_name=f"{self.collection_prefix}_{collection_name}"
  120. ):
  121. self._create_collection(
  122. collection_name=collection_name, dimension=len(items[0]["vector"])
  123. )
  124. return self.client.upsert(
  125. collection_name=f"{self.collection_prefix}_{collection_name}",
  126. data=[
  127. {
  128. "id": item["id"],
  129. "vector": item["vector"],
  130. "data": {"text": item["text"]},
  131. "metadata": item["metadata"],
  132. }
  133. for item in items
  134. ],
  135. )
  136. def delete(self, collection_name: str, ids: list[str]):
  137. # Delete the items from the collection based on the ids.
  138. return self.client.delete(
  139. collection_name=f"{self.collection_prefix}_{collection_name}",
  140. ids=ids,
  141. )
  142. def reset(self):
  143. # Resets the database. This will delete all collections and item entries.
  144. collection_names = self.client.list_collections()
  145. for collection_name in collection_names:
  146. if collection_name.startswith(self.collection_prefix):
  147. self.client.drop_collection(collection_name=collection_name)