chroma.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import chromadb
  2. from chromadb import Settings
  3. from open_webui.config import (
  4. CHROMA_DATA_PATH,
  5. CHROMA_HTTP_HOST,
  6. CHROMA_HTTP_PORT,
  7. CHROMA_HTTP_HEADERS,
  8. CHROMA_HTTP_SSL,
  9. CHROMA_TENANT,
  10. CHROMA_DATABASE,
  11. )
  class Chroma:
      """Thin wrapper around a chromadb client used as the vector store.

      The concrete client is chosen once, at construction time, from
      ``open_webui.config``: a remote ``chromadb.HttpClient`` when
      ``CHROMA_HTTP_HOST`` is set, otherwise an embedded
      ``chromadb.PersistentClient`` stored under ``CHROMA_DATA_PATH``.
      Every other method delegates to ``self.client``.
      """

      def __init__(self):
          # Both client kinds share the same settings. chromadb only permits
          # client.reset() (used by reset() below) when allow_reset is enabled.
          settings = Settings(allow_reset=True, anonymized_telemetry=False)

          # Truthiness rather than `!= ""` so that an unset/None host also
          # falls back to the embedded client instead of HttpClient(host=None).
          if CHROMA_HTTP_HOST:
              self.client = chromadb.HttpClient(
                  host=CHROMA_HTTP_HOST,
                  port=CHROMA_HTTP_PORT,
                  headers=CHROMA_HTTP_HEADERS,
                  ssl=CHROMA_HTTP_SSL,
                  tenant=CHROMA_TENANT,
                  database=CHROMA_DATABASE,
                  settings=settings,
              )
          else:
              self.client = chromadb.PersistentClient(
                  path=CHROMA_DATA_PATH,
                  settings=settings,
                  tenant=CHROMA_TENANT,
                  database=CHROMA_DATABASE,
              )

      def query_collection(self, name, query_embeddings, k):
          """Query collection ``name`` with a single embedding.

          Args:
              name: Name of the collection to query.
              query_embeddings: One embedding vector. It is wrapped in a
                  list because ``Collection.query`` expects a batch of
                  query embeddings.
              k: Number of results to request (passed as ``n_results``).

          Returns:
              The raw result of ``collection.query(...)``, or ``None`` when
              the client hands back no collection object.

          NOTE(review): chromadb's ``get_collection`` normally raises (the
          exception type depends on the chromadb version) rather than
          returning ``None`` for a missing collection; that exception
          propagates to the caller unchanged.
          """
          collection = self.client.get_collection(name=name)
          # Explicit None check: a collection object's truthiness is not an
          # existence test, and a falsy-but-valid collection must still be
          # queried.
          if collection is None:
              return None
          return collection.query(
              query_embeddings=[query_embeddings],
              n_results=k,
          )

      def list_collections(self):
          """Return whatever ``self.client.list_collections()`` returns."""
          return self.client.list_collections()

      def create_collection(self, name):
          """Create the collection ``name`` via the client and return it."""
          return self.client.create_collection(name=name)

      def get_or_create_collection(self, name):
          """Return collection ``name``, creating it first if it is absent."""
          return self.client.get_or_create_collection(name=name)

      def delete_collection(self, name):
          """Delete collection ``name`` via the client; return its result."""
          return self.client.delete_collection(name=name)

      def reset(self):
          """Reset the client's database; return the client's result.

          Relies on ``allow_reset=True`` being set in ``__init__``.
          """
          return self.client.reset()