milvus.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. from pymilvus import MilvusClient as Client
  2. from pymilvus import FieldSchema, DataType
  3. import json
  4. from typing import Optional
  5. from open_webui.apps.rag.vector.main import VectorItem, QueryResult
  6. from open_webui.config import (
  7. MILVUS_URI,
  8. )
  9. class MilvusClient:
  10. def __init__(self):
  11. self.collection_prefix = "open_webui"
  12. self.client = Client(uri=MILVUS_URI)
  13. def _result_to_query_result(self, result) -> QueryResult:
  14. print(result)
  15. ids = []
  16. distances = []
  17. documents = []
  18. metadatas = []
  19. for match in result:
  20. _ids = []
  21. _distances = []
  22. _documents = []
  23. _metadatas = []
  24. for item in match:
  25. _ids.append(item.get("id"))
  26. _distances.append(item.get("distance"))
  27. _documents.append(item.get("entity", {}).get("data", {}).get("text"))
  28. _metadatas.append(item.get("entity", {}).get("metadata"))
  29. ids.append(_ids)
  30. distances.append(_distances)
  31. documents.append(_documents)
  32. metadatas.append(_metadatas)
  33. return {
  34. "ids": ids,
  35. "distances": distances,
  36. "documents": documents,
  37. "metadatas": metadatas,
  38. }
  39. def _create_collection(self, collection_name: str, dimension: int):
  40. schema = self.client.create_schema(
  41. auto_id=False,
  42. enable_dynamic_field=True,
  43. )
  44. schema.add_field(
  45. field_name="id",
  46. datatype=DataType.VARCHAR,
  47. is_primary=True,
  48. max_length=65535,
  49. )
  50. schema.add_field(
  51. field_name="vector",
  52. datatype=DataType.FLOAT_VECTOR,
  53. dim=dimension,
  54. description="vector",
  55. )
  56. schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
  57. schema.add_field(
  58. field_name="metadata", datatype=DataType.JSON, description="metadata"
  59. )
  60. index_params = self.client.prepare_index_params()
  61. index_params.add_index(
  62. field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
  63. )
  64. self.client.create_collection(
  65. collection_name=f"{self.collection_prefix}_{collection_name}",
  66. schema=schema,
  67. index_params=index_params,
  68. )
  69. def list_collections(self) -> list[str]:
  70. # List all the collections in the database.
  71. return [
  72. collection[len(self.collection_prefix) :]
  73. for collection in self.client.list_collections()
  74. if collection.startswith(self.collection_prefix)
  75. ]
  76. def delete_collection(self, collection_name: str):
  77. # Delete the collection based on the collection name.
  78. return self.client.drop_collection(
  79. collection_name=f"{self.collection_prefix}_{collection_name}"
  80. )
  81. def search(
  82. self, collection_name: str, vectors: list[list[float | int]], limit: int
  83. ) -> Optional[QueryResult]:
  84. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  85. result = self.client.search(
  86. collection_name=f"{self.collection_prefix}_{collection_name}",
  87. data=vectors,
  88. limit=limit,
  89. output_fields=["data", "metadata"],
  90. )
  91. return self._result_to_query_result(result)
  92. def get(self, collection_name: str) -> Optional[QueryResult]:
  93. # Get all the items in the collection.
  94. result = self.client.query(
  95. collection_name=f"{self.collection_prefix}_{collection_name}",
  96. )
  97. return self._result_to_query_result(result)
  98. def insert(self, collection_name: str, items: list[VectorItem]):
  99. # Insert the items into the collection, if the collection does not exist, it will be created.
  100. if not self.client.has_collection(
  101. collection_name=f"{self.collection_prefix}_{collection_name}"
  102. ):
  103. self._create_collection(
  104. collection_name=collection_name, dimension=len(items[0]["vector"])
  105. )
  106. return self.client.insert(
  107. collection_name=f"{self.collection_prefix}_{collection_name}",
  108. data=[
  109. {
  110. "id": item["id"],
  111. "vector": item["vector"],
  112. "data": {"text": item["text"]},
  113. "metadata": item["metadata"],
  114. }
  115. for item in items
  116. ],
  117. )
  118. def upsert(self, collection_name: str, items: list[VectorItem]):
  119. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  120. if not self.client.has_collection(
  121. collection_name=f"{self.collection_prefix}_{collection_name}"
  122. ):
  123. self._create_collection(
  124. collection_name=collection_name, dimension=len(items[0]["vector"])
  125. )
  126. return self.client.upsert(
  127. collection_name=f"{self.collection_prefix}_{collection_name}",
  128. data=[
  129. {
  130. "id": item["id"],
  131. "vector": item["vector"],
  132. "data": {"text": item["text"]},
  133. "metadata": item["metadata"],
  134. }
  135. for item in items
  136. ],
  137. )
  138. def delete(self, collection_name: str, ids: list[str]):
  139. # Delete the items from the collection based on the ids.
  140. return self.client.delete(
  141. collection_name=f"{self.collection_prefix}_{collection_name}",
  142. ids=ids,
  143. )
  144. def reset(self):
  145. # Resets the database. This will delete all collections and item entries.
  146. collection_names = self.client.list_collections()
  147. for collection_name in collection_names:
  148. if collection_name.startswith(self.collection_prefix):
  149. self.client.drop_collection(collection_name=collection_name)