chroma.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import chromadb
  2. import logging
  3. from chromadb import Settings
  4. from chromadb.utils.batch_utils import create_batches
  5. from typing import Optional
  6. from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
  7. from open_webui.config import (
  8. CHROMA_DATA_PATH,
  9. CHROMA_HTTP_HOST,
  10. CHROMA_HTTP_PORT,
  11. CHROMA_HTTP_HEADERS,
  12. CHROMA_HTTP_SSL,
  13. CHROMA_TENANT,
  14. CHROMA_DATABASE,
  15. CHROMA_CLIENT_AUTH_PROVIDER,
  16. CHROMA_CLIENT_AUTH_CREDENTIALS,
  17. )
  18. from open_webui.env import SRC_LOG_LEVELS
# Module-level logger named after this module; its level is taken from the
# "RAG" entry of SRC_LOG_LEVELS (imported from open_webui.env).
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
  21. class ChromaClient:
  22. def __init__(self):
  23. settings_dict = {
  24. "allow_reset": True,
  25. "anonymized_telemetry": False,
  26. }
  27. if CHROMA_CLIENT_AUTH_PROVIDER is not None:
  28. settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
  29. if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
  30. settings_dict["chroma_client_auth_credentials"] = (
  31. CHROMA_CLIENT_AUTH_CREDENTIALS
  32. )
  33. if CHROMA_HTTP_HOST != "":
  34. self.client = chromadb.HttpClient(
  35. host=CHROMA_HTTP_HOST,
  36. port=CHROMA_HTTP_PORT,
  37. headers=CHROMA_HTTP_HEADERS,
  38. ssl=CHROMA_HTTP_SSL,
  39. tenant=CHROMA_TENANT,
  40. database=CHROMA_DATABASE,
  41. settings=Settings(**settings_dict),
  42. )
  43. else:
  44. self.client = chromadb.PersistentClient(
  45. path=CHROMA_DATA_PATH,
  46. settings=Settings(**settings_dict),
  47. tenant=CHROMA_TENANT,
  48. database=CHROMA_DATABASE,
  49. )
  50. def has_collection(self, collection_name: str) -> bool:
  51. # Check if the collection exists based on the collection name.
  52. collection_names = self.client.list_collections()
  53. return collection_name in collection_names
  54. def delete_collection(self, collection_name: str):
  55. # Delete the collection based on the collection name.
  56. return self.client.delete_collection(name=collection_name)
  57. def search(
  58. self, collection_name: str, vectors: list[list[float | int]], limit: int
  59. ) -> Optional[SearchResult]:
  60. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  61. try:
  62. collection = self.client.get_collection(name=collection_name)
  63. if collection:
  64. result = collection.query(
  65. query_embeddings=vectors,
  66. n_results=limit,
  67. )
  68. return SearchResult(
  69. **{
  70. "ids": result["ids"],
  71. "distances": result["distances"],
  72. "documents": result["documents"],
  73. "metadatas": result["metadatas"],
  74. }
  75. )
  76. return None
  77. except Exception as e:
  78. return None
  79. def query(
  80. self, collection_name: str, filter: dict, limit: Optional[int] = None
  81. ) -> Optional[GetResult]:
  82. # Query the items from the collection based on the filter.
  83. try:
  84. collection = self.client.get_collection(name=collection_name)
  85. if collection:
  86. result = collection.get(
  87. where=filter,
  88. limit=limit,
  89. )
  90. return GetResult(
  91. **{
  92. "ids": [result["ids"]],
  93. "documents": [result["documents"]],
  94. "metadatas": [result["metadatas"]],
  95. }
  96. )
  97. return None
  98. except:
  99. return None
  100. def get(self, collection_name: str) -> Optional[GetResult]:
  101. # Get all the items in the collection.
  102. collection = self.client.get_collection(name=collection_name)
  103. if collection:
  104. result = collection.get()
  105. return GetResult(
  106. **{
  107. "ids": [result["ids"]],
  108. "documents": [result["documents"]],
  109. "metadatas": [result["metadatas"]],
  110. }
  111. )
  112. return None
  113. def insert(self, collection_name: str, items: list[VectorItem]):
  114. # Insert the items into the collection, if the collection does not exist, it will be created.
  115. collection = self.client.get_or_create_collection(
  116. name=collection_name, metadata={"hnsw:space": "cosine"}
  117. )
  118. ids = [item["id"] for item in items]
  119. documents = [item["text"] for item in items]
  120. embeddings = [item["vector"] for item in items]
  121. metadatas = [item["metadata"] for item in items]
  122. for batch in create_batches(
  123. api=self.client,
  124. documents=documents,
  125. embeddings=embeddings,
  126. ids=ids,
  127. metadatas=metadatas,
  128. ):
  129. collection.add(*batch)
  130. def upsert(self, collection_name: str, items: list[VectorItem]):
  131. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  132. collection = self.client.get_or_create_collection(
  133. name=collection_name, metadata={"hnsw:space": "cosine"}
  134. )
  135. ids = [item["id"] for item in items]
  136. documents = [item["text"] for item in items]
  137. embeddings = [item["vector"] for item in items]
  138. metadatas = [item["metadata"] for item in items]
  139. collection.upsert(
  140. ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
  141. )
  142. def delete(
  143. self,
  144. collection_name: str,
  145. ids: Optional[list[str]] = None,
  146. filter: Optional[dict] = None,
  147. ):
  148. # Delete the items from the collection based on the ids.
  149. collection = self.client.get_collection(name=collection_name)
  150. if collection:
  151. if ids:
  152. collection.delete(ids=ids)
  153. elif filter:
  154. collection.delete(where=filter)
  155. def reset(self):
  156. # Resets the database. This will delete all collections and item entries.
  157. return self.client.reset()