# chroma.py
import logging
from typing import Optional

import chromadb
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches

from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    CHROMA_DATA_PATH,
    CHROMA_HTTP_HOST,
    CHROMA_HTTP_PORT,
    CHROMA_HTTP_HEADERS,
    CHROMA_HTTP_SSL,
    CHROMA_TENANT,
    CHROMA_DATABASE,
)
  15. class ChromaClient:
  16. def __init__(self):
  17. if CHROMA_HTTP_HOST != "":
  18. self.client = chromadb.HttpClient(
  19. host=CHROMA_HTTP_HOST,
  20. port=CHROMA_HTTP_PORT,
  21. headers=CHROMA_HTTP_HEADERS,
  22. ssl=CHROMA_HTTP_SSL,
  23. tenant=CHROMA_TENANT,
  24. database=CHROMA_DATABASE,
  25. settings=Settings(allow_reset=True, anonymized_telemetry=False),
  26. )
  27. else:
  28. self.client = chromadb.PersistentClient(
  29. path=CHROMA_DATA_PATH,
  30. settings=Settings(allow_reset=True, anonymized_telemetry=False),
  31. tenant=CHROMA_TENANT,
  32. database=CHROMA_DATABASE,
  33. )
  34. def has_collection(self, collection_name: str) -> bool:
  35. # Check if the collection exists based on the collection name.
  36. collections = self.client.list_collections()
  37. return collection_name in [collection.name for collection in collections]
  38. def delete_collection(self, collection_name: str):
  39. # Delete the collection based on the collection name.
  40. return self.client.delete_collection(name=collection_name)
  41. def search(
  42. self, collection_name: str, vectors: list[list[float | int]], limit: int
  43. ) -> Optional[SearchResult]:
  44. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  45. try:
  46. collection = self.client.get_collection(name=collection_name)
  47. if collection:
  48. result = collection.query(
  49. query_embeddings=vectors,
  50. n_results=limit,
  51. )
  52. return SearchResult(
  53. **{
  54. "ids": result["ids"],
  55. "distances": result["distances"],
  56. "documents": result["documents"],
  57. "metadatas": result["metadatas"],
  58. }
  59. )
  60. return None
  61. except Exception as e:
  62. return None
  63. def query(
  64. self, collection_name: str, filter: dict, limit: Optional[int] = None
  65. ) -> Optional[GetResult]:
  66. # Query the items from the collection based on the filter.
  67. try:
  68. collection = self.client.get_collection(name=collection_name)
  69. if collection:
  70. result = collection.get(
  71. where=filter,
  72. limit=limit,
  73. )
  74. return GetResult(
  75. **{
  76. "ids": [result["ids"]],
  77. "documents": [result["documents"]],
  78. "metadatas": [result["metadatas"]],
  79. }
  80. )
  81. return None
  82. except Exception as e:
  83. print(e)
  84. return None
  85. def get(self, collection_name: str) -> Optional[GetResult]:
  86. # Get all the items in the collection.
  87. collection = self.client.get_collection(name=collection_name)
  88. if collection:
  89. result = collection.get()
  90. return GetResult(
  91. **{
  92. "ids": [result["ids"]],
  93. "documents": [result["documents"]],
  94. "metadatas": [result["metadatas"]],
  95. }
  96. )
  97. return None
  98. def insert(self, collection_name: str, items: list[VectorItem]):
  99. # Insert the items into the collection, if the collection does not exist, it will be created.
  100. collection = self.client.get_or_create_collection(
  101. name=collection_name,
  102. metadata={"hnsw:space": "cosine"}
  103. )
  104. ids = [item["id"] for item in items]
  105. documents = [item["text"] for item in items]
  106. embeddings = [item["vector"] for item in items]
  107. metadatas = [item["metadata"] for item in items]
  108. for batch in create_batches(
  109. api=self.client,
  110. documents=documents,
  111. embeddings=embeddings,
  112. ids=ids,
  113. metadatas=metadatas,
  114. ):
  115. collection.add(*batch)
  116. def upsert(self, collection_name: str, items: list[VectorItem]):
  117. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  118. collection = self.client.get_or_create_collection(name=collection_name)
  119. ids = [item["id"] for item in items]
  120. documents = [item["text"] for item in items]
  121. embeddings = [item["vector"] for item in items]
  122. metadatas = [item["metadata"] for item in items]
  123. collection.upsert(
  124. ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
  125. )
  126. def delete(
  127. self,
  128. collection_name: str,
  129. ids: Optional[list[str]] = None,
  130. filter: Optional[dict] = None,
  131. ):
  132. # Delete the items from the collection based on the ids.
  133. collection = self.client.get_collection(name=collection_name)
  134. if collection:
  135. if ids:
  136. collection.delete(ids=ids)
  137. elif filter:
  138. collection.delete(where=filter)
  139. def reset(self):
  140. # Resets the database. This will delete all collections and item entries.
  141. return self.client.reset()