qdrant.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. from typing import Optional
  2. from qdrant_client import QdrantClient as Qclient
  3. from qdrant_client.http.models import PointStruct
  4. from qdrant_client.models import models
  5. from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
  6. from open_webui.config import QDRANT_URI, QDRANT_API_KEY
  7. NO_LIMIT = 999999999
  8. class QdrantClient:
  9. def __init__(self):
  10. self.collection_prefix = "open-webui"
  11. self.QDRANT_URI = QDRANT_URI
  12. self.QDRANT_API_KEY = QDRANT_API_KEY
  13. self.client = Qclient(
  14. url=self.QDRANT_URI,
  15. api_key=self.QDRANT_API_KEY
  16. ) if self.QDRANT_URI else None
  17. def _result_to_get_result(self, points) -> GetResult:
  18. ids = []
  19. documents = []
  20. metadatas = []
  21. for point in points:
  22. payload = point.payload
  23. ids.append(point.id)
  24. documents.append(payload["text"])
  25. metadatas.append(payload["metadata"])
  26. return GetResult(
  27. **{
  28. "ids": [ids],
  29. "documents": [documents],
  30. "metadatas": [metadatas],
  31. }
  32. )
  33. def _create_collection(self, collection_name: str, dimension: int):
  34. collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
  35. self.client.create_collection(
  36. collection_name=collection_name_with_prefix,
  37. vectors_config=models.VectorParams(
  38. size=dimension, distance=models.Distance.COSINE
  39. ),
  40. )
  41. print(f"collection {collection_name_with_prefix} successfully created!")
  42. def _create_collection_if_not_exists(self, collection_name, dimension):
  43. if not self.has_collection(collection_name=collection_name):
  44. self._create_collection(
  45. collection_name=collection_name, dimension=dimension
  46. )
  47. def _create_points(self, items: list[VectorItem]):
  48. return [
  49. PointStruct(
  50. id=item["id"],
  51. vector=item["vector"],
  52. payload={"text": item["text"], "metadata": item["metadata"]},
  53. )
  54. for item in items
  55. ]
  56. def has_collection(self, collection_name: str) -> bool:
  57. return self.client.collection_exists(
  58. f"{self.collection_prefix}_{collection_name}"
  59. )
  60. def delete_collection(self, collection_name: str):
  61. return self.client.delete_collection(
  62. collection_name=f"{self.collection_prefix}_{collection_name}"
  63. )
  64. def search(
  65. self, collection_name: str, vectors: list[list[float | int]], limit: int
  66. ) -> Optional[SearchResult]:
  67. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  68. if limit is None:
  69. limit = NO_LIMIT # otherwise qdrant would set limit to 10!
  70. query_response = self.client.query_points(
  71. collection_name=f"{self.collection_prefix}_{collection_name}",
  72. query=vectors[0],
  73. limit=limit,
  74. )
  75. get_result = self._result_to_get_result(query_response.points)
  76. return SearchResult(
  77. ids=get_result.ids,
  78. documents=get_result.documents,
  79. metadatas=get_result.metadatas,
  80. distances=[[point.score for point in query_response.points]],
  81. )
  82. def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
  83. # Construct the filter string for querying
  84. if not self.has_collection(collection_name):
  85. return None
  86. try:
  87. if limit is None:
  88. limit = NO_LIMIT # otherwise qdrant would set limit to 10!
  89. field_conditions = []
  90. for key, value in filter.items():
  91. field_conditions.append(
  92. models.FieldCondition(
  93. key=f"metadata.{key}", match=models.MatchValue(value=value)
  94. )
  95. )
  96. points = self.client.query_points(
  97. collection_name=f"{self.collection_prefix}_{collection_name}",
  98. query_filter=models.Filter(should=field_conditions),
  99. limit=limit,
  100. )
  101. return self._result_to_get_result(points.points)
  102. except Exception as e:
  103. print(e)
  104. return None
  105. def get(self, collection_name: str) -> Optional[GetResult]:
  106. # Get all the items in the collection.
  107. points = self.client.query_points(
  108. collection_name=f"{self.collection_prefix}_{collection_name}",
  109. limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
  110. )
  111. return self._result_to_get_result(points.points)
  112. def insert(self, collection_name: str, items: list[VectorItem]):
  113. # Insert the items into the collection, if the collection does not exist, it will be created.
  114. self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
  115. points = self._create_points(items)
  116. self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
  117. def upsert(self, collection_name: str, items: list[VectorItem]):
  118. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  119. self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
  120. points = self._create_points(items)
  121. return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
  122. def delete(
  123. self,
  124. collection_name: str,
  125. ids: Optional[list[str]] = None,
  126. filter: Optional[dict] = None,
  127. ):
  128. # Delete the items from the collection based on the ids.
  129. field_conditions = []
  130. if ids:
  131. for id_value in ids:
  132. field_conditions.append(
  133. models.FieldCondition(
  134. key="metadata.id",
  135. match=models.MatchValue(value=id_value),
  136. ),
  137. ),
  138. elif filter:
  139. for key, value in filter.items():
  140. field_conditions.append(
  141. models.FieldCondition(
  142. key=f"metadata.{key}",
  143. match=models.MatchValue(value=value),
  144. ),
  145. ),
  146. return self.client.delete(
  147. collection_name=f"{self.collection_prefix}_{collection_name}",
  148. points_selector=models.FilterSelector(
  149. filter=models.Filter(must=field_conditions)
  150. ),
  151. )
  152. def reset(self):
  153. # Resets the database. This will delete all collections and item entries.
  154. collection_names = self.client.get_collections().collections
  155. for collection_name in collection_names:
  156. if collection_name.name.startswith(self.collection_prefix):
  157. self.client.delete_collection(collection_name=collection_name.name)