chroma.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import chromadb
  2. from chromadb import Settings
  3. from chromadb.utils.batch_utils import create_batches
  4. from typing import Optional
  5. from open_webui.apps.rag.vector.main import VectorItem, QueryResult
  6. from open_webui.config import (
  7. CHROMA_DATA_PATH,
  8. CHROMA_HTTP_HOST,
  9. CHROMA_HTTP_PORT,
  10. CHROMA_HTTP_HEADERS,
  11. CHROMA_HTTP_SSL,
  12. CHROMA_TENANT,
  13. CHROMA_DATABASE,
  14. )
  15. class ChromaClient:
  16. def __init__(self):
  17. if CHROMA_HTTP_HOST != "":
  18. self.client = chromadb.HttpClient(
  19. host=CHROMA_HTTP_HOST,
  20. port=CHROMA_HTTP_PORT,
  21. headers=CHROMA_HTTP_HEADERS,
  22. ssl=CHROMA_HTTP_SSL,
  23. tenant=CHROMA_TENANT,
  24. database=CHROMA_DATABASE,
  25. settings=Settings(allow_reset=True, anonymized_telemetry=False),
  26. )
  27. else:
  28. self.client = chromadb.PersistentClient(
  29. path=CHROMA_DATA_PATH,
  30. settings=Settings(allow_reset=True, anonymized_telemetry=False),
  31. tenant=CHROMA_TENANT,
  32. database=CHROMA_DATABASE,
  33. )
  34. def list_collections(self) -> list[str]:
  35. # List all the collections in the database.
  36. collections = self.client.list_collections()
  37. return [collection.name for collection in collections]
  38. def create_collection(self, collection_name: str):
  39. # Create a new collection based on the collection name.
  40. return self.client.create_collection(name=collection_name)
  41. def delete_collection(self, collection_name: str):
  42. # Delete the collection based on the collection name.
  43. return self.client.delete_collection(name=collection_name)
  44. def search(
  45. self, collection_name: str, vectors: list[list[float | int]], limit: int
  46. ) -> Optional[QueryResult]:
  47. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  48. collection = self.client.get_collection(name=collection_name)
  49. if collection:
  50. result = collection.query(
  51. query_embeddings=vectors,
  52. n_results=limit,
  53. )
  54. return {
  55. "ids": result["ids"],
  56. "distances": result["distances"],
  57. "documents": result["documents"],
  58. "metadatas": result["metadatas"],
  59. }
  60. return None
  61. def get(self, collection_name: str) -> Optional[QueryResult]:
  62. # Get all the items in the collection.
  63. collection = self.client.get_collection(name=collection_name)
  64. if collection:
  65. return collection.get()
  66. return None
  67. def insert(self, collection_name: str, items: list[VectorItem]):
  68. # Insert the items into the collection.
  69. collection = self.client.get_or_create_collection(name=collection_name)
  70. ids = [item["id"] for item in items]
  71. documents = [item["text"] for item in items]
  72. embeddings = [item["vector"] for item in items]
  73. metadatas = [item["metadata"] for item in items]
  74. for batch in create_batches(
  75. api=self.client,
  76. documents=documents,
  77. embeddings=embeddings,
  78. ids=ids,
  79. metadatas=metadatas,
  80. ):
  81. collection.add(*batch)
  82. def upsert(self, collection_name: str, items: list[VectorItem]):
  83. # Update the items in the collection, if the items are not present, insert them.
  84. collection = self.client.get_or_create_collection(name=collection_name)
  85. ids = [item["id"] for item in items]
  86. documents = [item["text"] for item in items]
  87. embeddings = [item["vector"] for item in items]
  88. metadata = [item["metadata"] for item in items]
  89. collection.upsert(
  90. ids=ids, documents=documents, embeddings=embeddings, metadata=metadata
  91. )
  92. def delete(self, collection_name: str, ids: list[str]):
  93. # Delete the items from the collection based on the ids.
  94. collection = self.client.get_collection(name=collection_name)
  95. if collection:
  96. collection.delete(ids=ids)
  97. def reset(self):
  98. # Resets the database. This will delete all collections and item entries.
  99. return self.client.reset()