qdrant.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import logging
  2. from typing import Optional
  3. from qdrant_client import QdrantClient as Qclient
  4. from qdrant_client.http.models import PointStruct
  5. from qdrant_client.models import models
  6. from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
  7. from open_webui.config import QDRANT_URI
  8. log = logging.getLogger(__name__)
  9. log.setLevel("INFO")
  10. class QdrantClient:
  11. def __init__(self):
  12. self.collection_prefix = "open-webui"
  13. self.QDRANT_URI = QDRANT_URI
  14. self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
  15. def _result_to_get_result(self, points) -> GetResult:
  16. ids = []
  17. documents = []
  18. metadatas = []
  19. for point in points:
  20. payload = point.payload
  21. ids.append(point.id)
  22. documents.append(payload["text"])
  23. metadatas.append(payload["metadata"])
  24. return GetResult(
  25. **{
  26. "ids": [ids],
  27. "documents": [documents],
  28. "metadatas": [metadatas],
  29. }
  30. )
  31. def _create_collection(self, collection_name: str, dimension: int):
  32. collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
  33. self.client.create_collection(
  34. collection_name=collection_name_with_prefix,
  35. vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
  36. )
  37. log.info(f"collection {collection_name_with_prefix} successfully created!")
  38. def _create_collection_if_not_exists(self, collection_name, dimension):
  39. if not self.has_collection(
  40. collection_name=collection_name
  41. ):
  42. self._create_collection(
  43. collection_name=collection_name, dimension=dimension
  44. )
  45. def has_collection(self, collection_name: str) -> bool:
  46. return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}")
  47. def delete_collection(self, collection_name: str):
  48. return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}")
  49. def search(
  50. self, collection_name: str, vectors: list[list[float | int]], limit: int
  51. ) -> Optional[SearchResult]:
  52. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  53. log.info("start search...")
  54. query_response = self.client.query_points(
  55. collection_name=f"{self.collection_prefix}_{collection_name}",
  56. query=vectors[0],
  57. limit=limit,
  58. )
  59. get_result = self._result_to_get_result(query_response.points)
  60. return SearchResult(
  61. ids=get_result.ids,
  62. documents=get_result.documents,
  63. metadatas=get_result.metadatas,
  64. distances=[[point.score for point in query_response.points]]
  65. )
  66. def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
  67. # Construct the filter string for querying
  68. if not self.has_collection(collection_name):
  69. return None
  70. try:
  71. field_conditions = []
  72. for key, value in filter.items():
  73. field_conditions.append(
  74. models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
  75. log.info("start search...")
  76. points = self.client.query_points(
  77. collection_name=f"{self.collection_prefix}_{collection_name}",
  78. query_filter=models.Filter(should=field_conditions),
  79. limit=limit,
  80. )
  81. return self._result_to_get_result(points.points)
  82. except Exception as e:
  83. print(e)
  84. return None
  85. def get(self, collection_name: str) -> Optional[GetResult]:
  86. # Get all the items in the collection.
  87. points = self.client.query_points(
  88. collection_name=f"{self.collection_prefix}_{collection_name}",
  89. limit=10000000 # default is 10
  90. )
  91. return self._result_to_get_result(points.points)
  92. def insert(self, collection_name: str, items: list[VectorItem]):
  93. # Insert the items into the collection, if the collection does not exist, it will be created.
  94. self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
  95. points = self.create_points(items)
  96. self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
  97. def upsert(self, collection_name: str, items: list[VectorItem]):
  98. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  99. self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
  100. points = self.create_points(items)
  101. return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
  102. def delete(
  103. self,
  104. collection_name: str,
  105. ids: Optional[list[str]] = None,
  106. filter: Optional[dict] = None,
  107. ):
  108. # Delete the items from the collection based on the ids.
  109. field_conditions = []
  110. if ids:
  111. for id_value in ids:
  112. field_conditions.append(
  113. models.FieldCondition(
  114. key="metadata.id",
  115. match=models.MatchValue(value=id_value),
  116. ),
  117. ),
  118. elif filter:
  119. for key, value in filter.items():
  120. field_conditions.append(
  121. models.FieldCondition(
  122. key=f"metadata.{key}",
  123. match=models.MatchValue(value=value),
  124. ),
  125. ),
  126. return self.client.delete(
  127. collection_name=f"{self.collection_prefix}_{collection_name}",
  128. points_selector=models.FilterSelector(
  129. filter=models.Filter(
  130. must=field_conditions
  131. )
  132. ),
  133. )
  134. def reset(self):
  135. # Resets the database. This will delete all collections and item entries.
  136. collection_names = self.client.get_collections().collections
  137. for collection_name in collection_names:
  138. if collection_name.name.startswith(self.collection_prefix):
  139. self.client.delete_collection(collection_name=collection_name.name)
  140. def create_points(self, items: list[VectorItem]):
  141. vectors = [item["vector"] for item in items]
  142. log.info("insert points...")
  143. points = []
  144. for idx, item in enumerate(items):
  145. points.append(
  146. PointStruct(
  147. id=item["id"],
  148. vector=vectors[idx],
  149. payload={"text": item["text"], "metadata": item["metadata"]},
  150. )
  151. )
  152. return points