milvus.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
import json
import logging
from typing import Optional

from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    MILVUS_URI,
    MILVUS_DB,
    MILVUS_TOKEN,
)
  11. class MilvusClient:
  12. def __init__(self):
  13. self.collection_prefix = "open_webui"
  14. if MILVUS_TOKEN is None:
  15. self.client = Client(uri=MILVUS_URI, database=MILVUS_DB)
  16. else:
  17. self.client = Client(uri=MILVUS_URI, database=MILVUS_DB, token=MILVUS_TOKEN)
  18. def _result_to_get_result(self, result) -> GetResult:
  19. ids = []
  20. documents = []
  21. metadatas = []
  22. for match in result:
  23. _ids = []
  24. _documents = []
  25. _metadatas = []
  26. for item in match:
  27. _ids.append(item.get("id"))
  28. _documents.append(item.get("data", {}).get("text"))
  29. _metadatas.append(item.get("metadata"))
  30. ids.append(_ids)
  31. documents.append(_documents)
  32. metadatas.append(_metadatas)
  33. return GetResult(
  34. **{
  35. "ids": ids,
  36. "documents": documents,
  37. "metadatas": metadatas,
  38. }
  39. )
  40. def _result_to_search_result(self, result) -> SearchResult:
  41. ids = []
  42. distances = []
  43. documents = []
  44. metadatas = []
  45. for match in result:
  46. _ids = []
  47. _distances = []
  48. _documents = []
  49. _metadatas = []
  50. for item in match:
  51. _ids.append(item.get("id"))
  52. _distances.append(item.get("distance"))
  53. _documents.append(item.get("entity", {}).get("data", {}).get("text"))
  54. _metadatas.append(item.get("entity", {}).get("metadata"))
  55. ids.append(_ids)
  56. distances.append(_distances)
  57. documents.append(_documents)
  58. metadatas.append(_metadatas)
  59. return SearchResult(
  60. **{
  61. "ids": ids,
  62. "distances": distances,
  63. "documents": documents,
  64. "metadatas": metadatas,
  65. }
  66. )
  67. def _create_collection(self, collection_name: str, dimension: int):
  68. schema = self.client.create_schema(
  69. auto_id=False,
  70. enable_dynamic_field=True,
  71. )
  72. schema.add_field(
  73. field_name="id",
  74. datatype=DataType.VARCHAR,
  75. is_primary=True,
  76. max_length=65535,
  77. )
  78. schema.add_field(
  79. field_name="vector",
  80. datatype=DataType.FLOAT_VECTOR,
  81. dim=dimension,
  82. description="vector",
  83. )
  84. schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
  85. schema.add_field(
  86. field_name="metadata", datatype=DataType.JSON, description="metadata"
  87. )
  88. index_params = self.client.prepare_index_params()
  89. index_params.add_index(
  90. field_name="vector",
  91. index_type="HNSW",
  92. metric_type="COSINE",
  93. params={"M": 16, "efConstruction": 100},
  94. )
  95. self.client.create_collection(
  96. collection_name=f"{self.collection_prefix}_{collection_name}",
  97. schema=schema,
  98. index_params=index_params,
  99. )
  100. def has_collection(self, collection_name: str) -> bool:
  101. # Check if the collection exists based on the collection name.
  102. collection_name = collection_name.replace("-", "_")
  103. return self.client.has_collection(
  104. collection_name=f"{self.collection_prefix}_{collection_name}"
  105. )
  106. def delete_collection(self, collection_name: str):
  107. # Delete the collection based on the collection name.
  108. collection_name = collection_name.replace("-", "_")
  109. return self.client.drop_collection(
  110. collection_name=f"{self.collection_prefix}_{collection_name}"
  111. )
  112. def search(
  113. self, collection_name: str, vectors: list[list[float | int]], limit: int
  114. ) -> Optional[SearchResult]:
  115. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  116. collection_name = collection_name.replace("-", "_")
  117. result = self.client.search(
  118. collection_name=f"{self.collection_prefix}_{collection_name}",
  119. data=vectors,
  120. limit=limit,
  121. output_fields=["data", "metadata"],
  122. )
  123. return self._result_to_search_result(result)
  124. def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
  125. # Construct the filter string for querying
  126. collection_name = collection_name.replace("-", "_")
  127. if not self.has_collection(collection_name):
  128. return None
  129. filter_string = " && ".join(
  130. [
  131. f'metadata["{key}"] == {json.dumps(value)}'
  132. for key, value in filter.items()
  133. ]
  134. )
  135. max_limit = 16383 # The maximum number of records per request
  136. all_results = []
  137. if limit is None:
  138. limit = float("inf") # Use infinity as a placeholder for no limit
  139. # Initialize offset and remaining to handle pagination
  140. offset = 0
  141. remaining = limit
  142. try:
  143. # Loop until there are no more items to fetch or the desired limit is reached
  144. while remaining > 0:
  145. print("remaining", remaining)
  146. current_fetch = min(
  147. max_limit, remaining
  148. ) # Determine how many items to fetch in this iteration
  149. results = self.client.query(
  150. collection_name=f"{self.collection_prefix}_{collection_name}",
  151. filter=filter_string,
  152. output_fields=["*"],
  153. limit=current_fetch,
  154. offset=offset,
  155. )
  156. if not results:
  157. break
  158. all_results.extend(results)
  159. results_count = len(results)
  160. remaining -= (
  161. results_count # Decrease remaining by the number of items fetched
  162. )
  163. offset += results_count
  164. # Break the loop if the results returned are less than the requested fetch count
  165. if results_count < current_fetch:
  166. break
  167. print(all_results)
  168. return self._result_to_get_result([all_results])
  169. except Exception as e:
  170. print(e)
  171. return None
  172. def get(self, collection_name: str) -> Optional[GetResult]:
  173. # Get all the items in the collection.
  174. collection_name = collection_name.replace("-", "_")
  175. result = self.client.query(
  176. collection_name=f"{self.collection_prefix}_{collection_name}",
  177. filter='id != ""',
  178. )
  179. return self._result_to_get_result([result])
  180. def insert(self, collection_name: str, items: list[VectorItem]):
  181. # Insert the items into the collection, if the collection does not exist, it will be created.
  182. collection_name = collection_name.replace("-", "_")
  183. if not self.client.has_collection(
  184. collection_name=f"{self.collection_prefix}_{collection_name}"
  185. ):
  186. self._create_collection(
  187. collection_name=collection_name, dimension=len(items[0]["vector"])
  188. )
  189. return self.client.insert(
  190. collection_name=f"{self.collection_prefix}_{collection_name}",
  191. data=[
  192. {
  193. "id": item["id"],
  194. "vector": item["vector"],
  195. "data": {"text": item["text"]},
  196. "metadata": item["metadata"],
  197. }
  198. for item in items
  199. ],
  200. )
  201. def upsert(self, collection_name: str, items: list[VectorItem]):
  202. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  203. collection_name = collection_name.replace("-", "_")
  204. if not self.client.has_collection(
  205. collection_name=f"{self.collection_prefix}_{collection_name}"
  206. ):
  207. self._create_collection(
  208. collection_name=collection_name, dimension=len(items[0]["vector"])
  209. )
  210. return self.client.upsert(
  211. collection_name=f"{self.collection_prefix}_{collection_name}",
  212. data=[
  213. {
  214. "id": item["id"],
  215. "vector": item["vector"],
  216. "data": {"text": item["text"]},
  217. "metadata": item["metadata"],
  218. }
  219. for item in items
  220. ],
  221. )
  222. def delete(
  223. self,
  224. collection_name: str,
  225. ids: Optional[list[str]] = None,
  226. filter: Optional[dict] = None,
  227. ):
  228. # Delete the items from the collection based on the ids.
  229. collection_name = collection_name.replace("-", "_")
  230. if ids:
  231. return self.client.delete(
  232. collection_name=f"{self.collection_prefix}_{collection_name}",
  233. ids=ids,
  234. )
  235. elif filter:
  236. # Convert the filter dictionary to a string using JSON_CONTAINS.
  237. filter_string = " && ".join(
  238. [
  239. f'metadata["{key}"] == {json.dumps(value)}'
  240. for key, value in filter.items()
  241. ]
  242. )
  243. return self.client.delete(
  244. collection_name=f"{self.collection_prefix}_{collection_name}",
  245. filter=filter_string,
  246. )
  247. def reset(self):
  248. # Resets the database. This will delete all collections and item entries.
  249. collection_names = self.client.list_collections()
  250. for collection_name in collection_names:
  251. if collection_name.startswith(self.collection_prefix):
  252. self.client.drop_collection(collection_name=collection_name)