# chroma.py (6.1 KB)
import logging
from typing import Optional

import chromadb
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    CHROMA_DATA_PATH,
    CHROMA_HTTP_HOST,
    CHROMA_HTTP_PORT,
    CHROMA_HTTP_HEADERS,
    CHROMA_HTTP_SSL,
    CHROMA_TENANT,
    CHROMA_DATABASE,
    CHROMA_CLIENT_AUTH_PROVIDER,
    CHROMA_CLIENT_AUTH_CREDENTIALS,
)
  17. class ChromaClient:
  18. def __init__(self):
  19. settings_dict = {
  20. "allow_reset": True,
  21. "anonymized_telemetry": False,
  22. }
  23. if CHROMA_CLIENT_AUTH_PROVIDER is not None:
  24. settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
  25. if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
  26. settings_dict["chroma_client_auth_credentials"] = (
  27. CHROMA_CLIENT_AUTH_CREDENTIALS
  28. )
  29. if CHROMA_HTTP_HOST != "":
  30. self.client = chromadb.HttpClient(
  31. host=CHROMA_HTTP_HOST,
  32. port=CHROMA_HTTP_PORT,
  33. headers=CHROMA_HTTP_HEADERS,
  34. ssl=CHROMA_HTTP_SSL,
  35. tenant=CHROMA_TENANT,
  36. database=CHROMA_DATABASE,
  37. settings=Settings(**settings_dict),
  38. )
  39. else:
  40. self.client = chromadb.PersistentClient(
  41. path=CHROMA_DATA_PATH,
  42. settings=Settings(**settings_dict),
  43. tenant=CHROMA_TENANT,
  44. database=CHROMA_DATABASE,
  45. )
  46. def has_collection(self, collection_name: str) -> bool:
  47. # Check if the collection exists based on the collection name.
  48. collections = self.client.list_collections()
  49. return collection_name in [collection.name for collection in collections]
  50. def delete_collection(self, collection_name: str):
  51. # Delete the collection based on the collection name.
  52. return self.client.delete_collection(name=collection_name)
  53. def search(
  54. self, collection_name: str, vectors: list[list[float | int]], limit: int
  55. ) -> Optional[SearchResult]:
  56. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  57. try:
  58. collection = self.client.get_collection(name=collection_name)
  59. if collection:
  60. result = collection.query(
  61. query_embeddings=vectors,
  62. n_results=limit,
  63. )
  64. return SearchResult(
  65. **{
  66. "ids": result["ids"],
  67. "distances": result["distances"],
  68. "documents": result["documents"],
  69. "metadatas": result["metadatas"],
  70. }
  71. )
  72. return None
  73. except Exception as e:
  74. return None
  75. def query(
  76. self, collection_name: str, filter: dict, limit: Optional[int] = None
  77. ) -> Optional[GetResult]:
  78. # Query the items from the collection based on the filter.
  79. try:
  80. collection = self.client.get_collection(name=collection_name)
  81. if collection:
  82. result = collection.get(
  83. where=filter,
  84. limit=limit,
  85. )
  86. return GetResult(
  87. **{
  88. "ids": [result["ids"]],
  89. "documents": [result["documents"]],
  90. "metadatas": [result["metadatas"]],
  91. }
  92. )
  93. return None
  94. except Exception as e:
  95. print(e)
  96. return None
  97. def get(self, collection_name: str) -> Optional[GetResult]:
  98. # Get all the items in the collection.
  99. collection = self.client.get_collection(name=collection_name)
  100. if collection:
  101. result = collection.get()
  102. return GetResult(
  103. **{
  104. "ids": [result["ids"]],
  105. "documents": [result["documents"]],
  106. "metadatas": [result["metadatas"]],
  107. }
  108. )
  109. return None
  110. def insert(self, collection_name: str, items: list[VectorItem]):
  111. # Insert the items into the collection, if the collection does not exist, it will be created.
  112. collection = self.client.get_or_create_collection(
  113. name=collection_name, metadata={"hnsw:space": "cosine"}
  114. )
  115. ids = [item["id"] for item in items]
  116. documents = [item["text"] for item in items]
  117. embeddings = [item["vector"] for item in items]
  118. metadatas = [item["metadata"] for item in items]
  119. for batch in create_batches(
  120. api=self.client,
  121. documents=documents,
  122. embeddings=embeddings,
  123. ids=ids,
  124. metadatas=metadatas,
  125. ):
  126. collection.add(*batch)
  127. def upsert(self, collection_name: str, items: list[VectorItem]):
  128. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  129. collection = self.client.get_or_create_collection(
  130. name=collection_name, metadata={"hnsw:space": "cosine"}
  131. )
  132. ids = [item["id"] for item in items]
  133. documents = [item["text"] for item in items]
  134. embeddings = [item["vector"] for item in items]
  135. metadatas = [item["metadata"] for item in items]
  136. collection.upsert(
  137. ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
  138. )
  139. def delete(
  140. self,
  141. collection_name: str,
  142. ids: Optional[list[str]] = None,
  143. filter: Optional[dict] = None,
  144. ):
  145. # Delete the items from the collection based on the ids.
  146. collection = self.client.get_collection(name=collection_name)
  147. if collection:
  148. if ids:
  149. collection.delete(ids=ids)
  150. elif filter:
  151. collection.delete(where=filter)
  152. def reset(self):
  153. # Resets the database. This will delete all collections and item entries.
  154. return self.client.reset()