123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- import chromadb
- import logging
- from chromadb import Settings
- from chromadb.utils.batch_utils import create_batches
- from typing import Optional
- from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
- from open_webui.config import (
- CHROMA_DATA_PATH,
- CHROMA_HTTP_HOST,
- CHROMA_HTTP_PORT,
- CHROMA_HTTP_HEADERS,
- CHROMA_HTTP_SSL,
- CHROMA_TENANT,
- CHROMA_DATABASE,
- CHROMA_CLIENT_AUTH_PROVIDER,
- CHROMA_CLIENT_AUTH_CREDENTIALS,
- )
- from open_webui.env import SRC_LOG_LEVELS
# Module-level logger; its threshold is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient:
    """Wrapper around a Chroma client exposing collection-level vector operations.

    The underlying client is a ``chromadb.HttpClient`` when ``CHROMA_HTTP_HOST``
    is not the empty string, and a ``chromadb.PersistentClient`` rooted at
    ``CHROMA_DATA_PATH`` otherwise.
    """

    def __init__(self):
        """Assemble the Chroma settings and create the configured client."""
        settings_dict = {
            "allow_reset": True,
            "anonymized_telemetry": False,
        }
        # Auth options are forwarded only when they are configured.
        if CHROMA_CLIENT_AUTH_PROVIDER is not None:
            settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
        if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
            settings_dict["chroma_client_auth_credentials"] = (
                CHROMA_CLIENT_AUTH_CREDENTIALS
            )

        if CHROMA_HTTP_HOST != "":
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=Settings(**settings_dict),
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=Settings(**settings_dict),
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    def has_collection(self, collection_name: str) -> bool:
        """Return whether a collection named ``collection_name`` exists.

        Depending on the installed chromadb release, ``list_collections()``
        yields either plain name strings or collection objects exposing a
        ``name`` attribute; both shapes are accepted here.
        """
        return any(
            # Fall back to the entry itself when it has no ``name`` attribute
            # (i.e. it already is the collection name).
            getattr(entry, "name", entry) == collection_name
            for entry in self.client.list_collections()
        )

    def delete_collection(self, collection_name: str):
        """Delete the collection named ``collection_name``.

        Returns:
            Whatever the underlying client's ``delete_collection`` returns.
        """
        return self.client.delete_collection(name=collection_name)

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Run a nearest-neighbour query against a collection.

        Args:
            collection_name: Name of the collection to query.
            vectors: Query embeddings, one per query.
            limit: Number of results requested per query embedding.

        Returns:
            A ``SearchResult`` built from the ``ids``, ``distances``,
            ``documents`` and ``metadatas`` fields of the query response, or
            ``None`` when the collection is falsy or any exception is raised.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None

            result = collection.query(
                query_embeddings=vectors,
                n_results=limit,
            )
            return SearchResult(
                ids=result["ids"],
                distances=result["distances"],
                documents=result["documents"],
                metadatas=result["metadatas"],
            )
        except Exception:
            # Best-effort contract: callers still receive None, but the failure
            # is recorded instead of disappearing silently.
            log.exception("Chroma search failed for collection %s", collection_name)
            return None

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch items from a collection that match a metadata filter.

        Args:
            collection_name: Name of the collection to read.
            filter: Passed to the collection's ``get`` as ``where``.
            limit: Passed to the collection's ``get`` as ``limit``.

        Returns:
            A ``GetResult`` whose ``ids``, ``documents`` and ``metadatas`` are
            each wrapped in a single-element outer list, or ``None`` when the
            collection is falsy or any exception is raised.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None

            result = collection.get(
                where=filter,
                limit=limit,
            )
            return GetResult(
                ids=[result["ids"]],
                documents=[result["documents"]],
                metadatas=[result["metadatas"]],
            )
        except Exception:
            # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt
            # and SystemExit are no longer swallowed.
            log.exception("Chroma query failed for collection %s", collection_name)
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Fetch every item stored in a collection.

        Unlike ``search`` and ``query`` there is no exception handler here, so
        errors raised by the underlying client propagate to the caller.

        Returns:
            A ``GetResult`` whose ``ids``, ``documents`` and ``metadatas`` are
            each wrapped in a single-element outer list, or ``None`` when the
            collection is falsy.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return None

        result = collection.get()
        return GetResult(
            ids=[result["ids"]],
            documents=[result["documents"]],
            metadatas=[result["metadatas"]],
        )

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add items to a collection, creating the collection if needed.

        The collection is obtained with ``get_or_create_collection`` and the
        metadata ``{"hnsw:space": "cosine"}``. Items are split with
        ``create_batches`` and each batch is passed to ``collection.add``.

        Args:
            collection_name: Target collection name.
            items: Mappings providing ``id``, ``text``, ``vector`` and
                ``metadata`` keys.
        """
        collection = self.client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )

        ids = [item["id"] for item in items]
        documents = [item["text"] for item in items]
        embeddings = [item["vector"] for item in items]
        metadatas = [item["metadata"] for item in items]

        for batch in create_batches(
            api=self.client,
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
        ):
            collection.add(*batch)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update items in a collection, creating it if needed.

        The collection is obtained with ``get_or_create_collection`` and the
        metadata ``{"hnsw:space": "cosine"}``. Items are split with
        ``create_batches`` exactly as ``insert`` does, so a large payload is
        sent as several ``collection.upsert`` calls rather than one.

        Args:
            collection_name: Target collection name.
            items: Mappings providing ``id``, ``text``, ``vector`` and
                ``metadata`` keys.
        """
        collection = self.client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )

        ids = [item["id"] for item in items]
        documents = [item["text"] for item in items]
        embeddings = [item["vector"] for item in items]
        metadatas = [item["metadata"] for item in items]

        for batch in create_batches(
            api=self.client,
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
        ):
            # Same positional hand-off that ``insert`` uses for ``add``.
            collection.upsert(*batch)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete items from a collection by ids or by metadata filter.

        ``ids`` takes precedence: ``filter`` is only used when ``ids`` is
        falsy. When both are falsy nothing is deleted. There is no exception
        handler, so errors raised by the underlying client propagate.

        Args:
            collection_name: Name of the collection to delete from.
            ids: Item ids to delete.
            filter: Passed to the collection's ``delete`` as ``where``.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return

        if ids:
            collection.delete(ids=ids)
        elif filter:
            collection.delete(where=filter)

    def reset(self):
        """Reset the database via the underlying client's ``reset``.

        Per the original author's note, this deletes all collections and item
        entries.

        Returns:
            Whatever the underlying client's ``reset`` returns.
        """
        return self.client.reset()
|