浏览代码

Merge pull request #5312 from open-webui/multiple-vector-dbs

feat: various vector db support
Timothy Jaeryang Baek 7 月之前
父节点
当前提交
c7fc17da69

+ 16 - 19
backend/open_webui/apps/rag/main.py

@@ -96,7 +96,6 @@ from open_webui.utils.misc import (
 from open_webui.utils.utils import get_admin_user, get_verified_user
 from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
 
-from chromadb.utils.batch_utils import create_batches
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import (
     BSHTMLLoader,
@@ -998,14 +997,11 @@ def store_docs_in_vector_db(
 
     try:
         if overwrite:
-            for collection in VECTOR_DB_CLIENT.list_collections():
-                if collection_name == collection.name:
-                    log.info(f"deleting existing collection {collection_name}")
-                    VECTOR_DB_CLIENT.delete_collection(name=collection_name)
+            if collection_name in VECTOR_DB_CLIENT.list_collections():
+                log.info(f"deleting existing collection {collection_name}")
+                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
 
-        collection = VECTOR_DB_CLIENT.create_collection(name=collection_name)
-
-        embedding_func = get_embedding_function(
+        embedding_function = get_embedding_function(
             app.state.config.RAG_EMBEDDING_ENGINE,
             app.state.config.RAG_EMBEDDING_MODEL,
             app.state.sentence_transformer_ef,
@@ -1014,17 +1010,18 @@ def store_docs_in_vector_db(
             app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
         )
 
-        embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
-        embeddings = embedding_func(embedding_texts)
-
-        for batch in create_batches(
-            api=VECTOR_DB_CLIENT,
-            ids=[str(uuid.uuid4()) for _ in texts],
-            metadatas=metadatas,
-            embeddings=embeddings,
-            documents=texts,
-        ):
-            collection.add(*batch)
+        VECTOR_DB_CLIENT.insert(
+            collection_name=collection_name,
+            items=[
+                {
+                    "id": str(uuid.uuid4()),
+                    "text": text,
+                    "vector": embedding_function(text.replace("\n", " ")),
+                    "metadata": metadatas[idx],
+                }
+                for idx, text in enumerate(texts)
+            ],
+        )
 
         return True
     except Exception as e:

+ 52 - 52
backend/open_webui/apps/rag/utils.py

@@ -24,6 +24,44 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
+from typing import Any
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.retrievers import BaseRetriever
+
+
+class VectorSearchRetriever(BaseRetriever):
+    collection_name: Any
+    embedding_function: Any
+    top_k: int
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+    ) -> list[Document]:
+        result = VECTOR_DB_CLIENT.search(
+            collection_name=self.collection_name,
+            vectors=[self.embedding_function(query)],
+            limit=self.top_k,
+        )
+
+        ids = result["ids"][0]
+        metadatas = result["metadatas"][0]
+        documents = result["documents"][0]
+
+        results = []
+        for idx in range(len(ids)):
+            results.append(
+                Document(
+                    metadata=metadatas[idx],
+                    page_content=documents[idx],
+                )
+            )
+        return results
+
+
 def query_doc(
     collection_name: str,
     query: str,
@@ -31,15 +69,18 @@ def query_doc(
     k: int,
 ):
     try:
-        result = VECTOR_DB_CLIENT.query_collection(
-            name=collection_name,
-            query_embeddings=embedding_function(query),
-            k=k,
+        result = VECTOR_DB_CLIENT.search(
+            collection_name=collection_name,
+            vectors=[embedding_function(query)],
+            limit=k,
         )
 
+        print("result", result)
+
         log.info(f"query_doc:result {result}")
         return result
     except Exception as e:
+        print(e)
         raise e
 
 
@@ -52,25 +93,23 @@ def query_doc_with_hybrid_search(
     r: float,
 ):
     try:
-        collection = VECTOR_DB_CLIENT.get_collection(name=collection_name)
-        documents = collection.get()  # get all documents
+        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
 
         bm25_retriever = BM25Retriever.from_texts(
-            texts=documents.get("documents"),
-            metadatas=documents.get("metadatas"),
+            texts=result.documents,
+            metadatas=result.metadatas,
         )
         bm25_retriever.k = k
 
-        chroma_retriever = ChromaRetriever(
-            collection=collection,
+        vector_search_retriever = VectorSearchRetriever(
+            collection_name=collection_name,
             embedding_function=embedding_function,
-            top_n=k,
+            top_k=k,
         )
 
         ensemble_retriever = EnsembleRetriever(
-            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
+            retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
         )
-
         compressor = RerankCompressor(
             embedding_function=embedding_function,
             top_n=k,
@@ -394,45 +433,6 @@ def generate_openai_batch_embeddings(
         return None
 
 
-from typing import Any
-
-from langchain_core.callbacks import CallbackManagerForRetrieverRun
-from langchain_core.retrievers import BaseRetriever
-
-
-class ChromaRetriever(BaseRetriever):
-    collection: Any
-    embedding_function: Any
-    top_n: int
-
-    def _get_relevant_documents(
-        self,
-        query: str,
-        *,
-        run_manager: CallbackManagerForRetrieverRun,
-    ) -> list[Document]:
-        query_embeddings = self.embedding_function(query)
-
-        results = self.collection.query(
-            query_embeddings=[query_embeddings],
-            n_results=self.top_n,
-        )
-
-        ids = results["ids"][0]
-        metadatas = results["metadatas"][0]
-        documents = results["documents"][0]
-
-        results = []
-        for idx in range(len(ids)):
-            results.append(
-                Document(
-                    metadata=metadatas[idx],
-                    page_content=documents[idx],
-                )
-            )
-        return results
-
-
 import operator
 from typing import Optional, Sequence
 

+ 8 - 2
backend/open_webui/apps/rag/vector/connector.py

@@ -1,4 +1,10 @@
-from open_webui.apps.rag.vector.dbs.chroma import Chroma
+from open_webui.apps.rag.vector.dbs.chroma import ChromaClient
+from open_webui.apps.rag.vector.dbs.milvus import MilvusClient
+
+
 from open_webui.config import VECTOR_DB
 
-VECTOR_DB_CLIENT = Chroma()
+if VECTOR_DB == "milvus":
+    VECTOR_DB_CLIENT = MilvusClient()
+else:
+    VECTOR_DB_CLIENT = ChromaClient()

+ 69 - 14
backend/open_webui/apps/rag/vector/dbs/chroma.py

@@ -1,6 +1,10 @@
 import chromadb
 from chromadb import Settings
+from chromadb.utils.batch_utils import create_batches
 
+from typing import Optional
+
+from open_webui.apps.rag.vector.main import VectorItem, QueryResult
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
@@ -12,7 +16,7 @@ from open_webui.config import (
 )
 
 
-class Chroma:
+class ChromaClient:
     def __init__(self):
         if CHROMA_HTTP_HOST != "":
             self.client = chromadb.HttpClient(
@@ -32,27 +36,78 @@ class Chroma:
                 database=CHROMA_DATABASE,
             )
 
-    def query_collection(self, name, query_embeddings, k):
-        collection = self.client.get_collection(name=name)
+    def list_collections(self) -> list[str]:
+        # List all the collections in the database.
+        collections = self.client.list_collections()
+        return [collection.name for collection in collections]
+
+    def delete_collection(self, collection_name: str):
+        # Delete the collection based on the collection name.
+        return self.client.delete_collection(name=collection_name)
+
+    def search(
+        self, collection_name: str, vectors: list[list[float | int]], limit: int
+    ) -> Optional[QueryResult]:
+        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
+        collection = self.client.get_collection(name=collection_name)
         if collection:
             result = collection.query(
-                query_embeddings=[query_embeddings],
-                n_results=k,
+                query_embeddings=vectors,
+                n_results=limit,
             )
-            return result
+
+            return {
+                "ids": result["ids"],
+                "distances": result["distances"],
+                "documents": result["documents"],
+                "metadatas": result["metadatas"],
+            }
         return None
 
-    def list_collections(self):
-        return self.client.list_collections()
+    def get(self, collection_name: str) -> Optional[QueryResult]:
+        # Get all the items in the collection.
+        collection = self.client.get_collection(name=collection_name)
+        if collection:
+            return collection.get()
+        return None
 
-    def create_collection(self, name):
-        return self.client.create_collection(name=name)
+    def insert(self, collection_name: str, items: list[VectorItem]):
+        # Insert the items into the collection, if the collection does not exist, it will be created.
+        collection = self.client.get_or_create_collection(name=collection_name)
 
-    def get_or_create_collection(self, name):
-        return self.client.get_or_create_collection(name=name)
+        ids = [item["id"] for item in items]
+        documents = [item["text"] for item in items]
+        embeddings = [item["vector"] for item in items]
+        metadatas = [item["metadata"] for item in items]
 
-    def delete_collection(self, name):
-        return self.client.delete_collection(name=name)
+        for batch in create_batches(
+            api=self.client,
+            documents=documents,
+            embeddings=embeddings,
+            ids=ids,
+            metadatas=metadatas,
+        ):
+            collection.add(*batch)
+
+    def upsert(self, collection_name: str, items: list[VectorItem]):
+        # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
+        collection = self.client.get_or_create_collection(name=collection_name)
+
+        ids = [item["id"] for item in items]
+        documents = [item["text"] for item in items]
+        embeddings = [item["vector"] for item in items]
+        metadata = [item["metadata"] for item in items]
+
+        collection.upsert(
+            ids=ids, documents=documents, embeddings=embeddings, metadata=metadata
+        )
+
+    def delete(self, collection_name: str, ids: list[str]):
+        # Delete the items from the collection based on the ids.
+        collection = self.client.get_collection(name=collection_name)
+        if collection:
+            collection.delete(ids=ids)
 
     def reset(self):
+        # Resets the database. This will delete all collections and item entries.
         return self.client.reset()

+ 175 - 0
backend/open_webui/apps/rag/vector/dbs/milvus.py

@@ -0,0 +1,175 @@
+from pymilvus import MilvusClient as Client
+from pymilvus import FieldSchema, DataType
+import json
+
+from typing import Optional
+
+from open_webui.apps.rag.vector.main import VectorItem, QueryResult
+from open_webui.config import (
+    MILVUS_URI,
+)
+
+
+class MilvusClient:
+    def __init__(self):
+        self.collection_prefix = "open_webui"
+        self.client = Client(uri=MILVUS_URI)
+
+    def _result_to_query_result(self, result) -> QueryResult:
+        print(result)
+
+        ids = []
+        distances = []
+        documents = []
+        metadatas = []
+
+        for match in result:
+            _ids = []
+            _distances = []
+            _documents = []
+            _metadatas = []
+
+            for item in match:
+                _ids.append(item.get("id"))
+                _distances.append(item.get("distance"))
+                _documents.append(item.get("entity", {}).get("data", {}).get("text"))
+                _metadatas.append(item.get("entity", {}).get("metadata"))
+
+            ids.append(_ids)
+            distances.append(_distances)
+            documents.append(_documents)
+            metadatas.append(_metadatas)
+
+        return {
+            "ids": ids,
+            "distances": distances,
+            "documents": documents,
+            "metadatas": metadatas,
+        }
+
+    def _create_collection(self, collection_name: str, dimension: int):
+        schema = self.client.create_schema(
+            auto_id=False,
+            enable_dynamic_field=True,
+        )
+        schema.add_field(
+            field_name="id",
+            datatype=DataType.VARCHAR,
+            is_primary=True,
+            max_length=65535,
+        )
+        schema.add_field(
+            field_name="vector",
+            datatype=DataType.FLOAT_VECTOR,
+            dim=dimension,
+            description="vector",
+        )
+        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
+        schema.add_field(
+            field_name="metadata", datatype=DataType.JSON, description="metadata"
+        )
+
+        index_params = self.client.prepare_index_params()
+        index_params.add_index(
+            field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
+        )
+
+        self.client.create_collection(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            schema=schema,
+            index_params=index_params,
+        )
+
+    def list_collections(self) -> list[str]:
+        # List all the collections in the database.
+        return [
+            collection[len(self.collection_prefix) :]
+            for collection in self.client.list_collections()
+            if collection.startswith(self.collection_prefix)
+        ]
+
+    def delete_collection(self, collection_name: str):
+        # Delete the collection based on the collection name.
+        return self.client.drop_collection(
+            collection_name=f"{self.collection_prefix}_{collection_name}"
+        )
+
+    def search(
+        self, collection_name: str, vectors: list[list[float | int]], limit: int
+    ) -> Optional[QueryResult]:
+        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
+        result = self.client.search(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            data=vectors,
+            limit=limit,
+            output_fields=["data", "metadata"],
+        )
+
+        return self._result_to_query_result(result)
+
+    def get(self, collection_name: str) -> Optional[QueryResult]:
+        # Get all the items in the collection.
+        result = self.client.query(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+        )
+        return self._result_to_query_result(result)
+
+    def insert(self, collection_name: str, items: list[VectorItem]):
+        # Insert the items into the collection, if the collection does not exist, it will be created.
+        if not self.client.has_collection(
+            collection_name=f"{self.collection_prefix}_{collection_name}"
+        ):
+            self._create_collection(
+                collection_name=collection_name, dimension=len(items[0]["vector"])
+            )
+
+        return self.client.insert(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            data=[
+                {
+                    "id": item["id"],
+                    "vector": item["vector"],
+                    "data": {"text": item["text"]},
+                    "metadata": item["metadata"],
+                }
+                for item in items
+            ],
+        )
+
+    def upsert(self, collection_name: str, items: list[VectorItem]):
+        # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
+        if not self.client.has_collection(
+            collection_name=f"{self.collection_prefix}_{collection_name}"
+        ):
+            self._create_collection(
+                collection_name=collection_name, dimension=len(items[0]["vector"])
+            )
+
+        return self.client.upsert(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            data=[
+                {
+                    "id": item["id"],
+                    "vector": item["vector"],
+                    "data": {"text": item["text"]},
+                    "metadata": item["metadata"],
+                }
+                for item in items
+            ],
+        )
+
+    def delete(self, collection_name: str, ids: list[str]):
+        # Delete the items from the collection based on the ids.
+
+        return self.client.delete(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            ids=ids,
+        )
+
+    def reset(self):
+        # Resets the database. This will delete all collections and item entries.
+
+        collection_names = self.client.list_collections()
+        for collection_name in collection_names:
+            if collection_name.startswith(self.collection_prefix):
+                self.client.drop_collection(collection_name=collection_name)

+ 16 - 0
backend/open_webui/apps/rag/vector/main.py

@@ -0,0 +1,16 @@
+from pydantic import BaseModel
+from typing import Optional, List, Any
+
+
+class VectorItem(BaseModel):
+    id: str
+    text: str
+    vector: List[float | int]
+    metadata: Any
+
+
+class QueryResult(BaseModel):
+    ids: Optional[List[List[str]]]
+    distances: Optional[List[List[float | int]]]
+    documents: Optional[List[List[str]]]
+    metadatas: Optional[List[List[Any]]]

+ 44 - 40
backend/open_webui/apps/webui/routers/memories.py

@@ -50,16 +50,17 @@ async def add_memory(
     user=Depends(get_verified_user),
 ):
     memory = Memories.insert_new_memory(user.id, form_data.content)
-    memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
 
-    collection = VECTOR_DB_CLIENT.get_or_create_collection(
-        name=f"user-memory-{user.id}"
-    )
-    collection.upsert(
-        documents=[memory.content],
-        ids=[memory.id],
-        embeddings=[memory_embedding],
-        metadatas=[{"created_at": memory.created_at}],
+    VECTOR_DB_CLIENT.upsert(
+        collection_name=f"user-memory-{user.id}",
+        items=[
+            {
+                "id": memory.id,
+                "text": memory.content,
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "metadata": {"created_at": memory.created_at},
+            }
+        ],
     )
 
     return memory
@@ -79,14 +80,10 @@ class QueryMemoryForm(BaseModel):
 async def query_memory(
     request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
 ):
-    query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
-    collection = VECTOR_DB_CLIENT.get_or_create_collection(
-        name=f"user-memory-{user.id}"
-    )
-
-    results = collection.query(
-        query_embeddings=[query_embedding],
-        n_results=form_data.k,  # how many results to return
+    results = VECTOR_DB_CLIENT.search(
+        name=f"user-memory-{user.id}",
+        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
+        limit=form_data.k,
     )
 
     return results
@@ -100,18 +97,24 @@ async def reset_memory_from_vector_db(
     request: Request, user=Depends(get_verified_user)
 ):
     VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
-    collection = VECTOR_DB_CLIENT.get_or_create_collection(
-        name=f"user-memory-{user.id}"
-    )
 
     memories = Memories.get_memories_by_user_id(user.id)
-    for memory in memories:
-        memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
-        collection.upsert(
-            documents=[memory.content],
-            ids=[memory.id],
-            embeddings=[memory_embedding],
-        )
+    VECTOR_DB_CLIENT.upsert(
+        collection_name=f"user-memory-{user.id}",
+        items=[
+            {
+                "id": memory.id,
+                "text": memory.content,
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "metadata": {
+                    "created_at": memory.created_at,
+                    "updated_at": memory.updated_at,
+                },
+            }
+            for memory in memories
+        ],
+    )
+
     return True
 
 
@@ -151,16 +154,18 @@ async def update_memory_by_id(
         raise HTTPException(status_code=404, detail="Memory not found")
 
     if form_data.content is not None:
-        memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
-        collection = VECTOR_DB_CLIENT.get_or_create_collection(
-            name=f"user-memory-{user.id}"
-        )
-        collection.upsert(
-            documents=[form_data.content],
-            ids=[memory.id],
-            embeddings=[memory_embedding],
-            metadatas=[
-                {"created_at": memory.created_at, "updated_at": memory.updated_at}
+        VECTOR_DB_CLIENT.upsert(
+            collection_name=f"user-memory-{user.id}",
+            items=[
+                {
+                    "id": memory.id,
+                    "text": memory.content,
+                    "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                    "metadata": {
+                        "created_at": memory.created_at,
+                        "updated_at": memory.updated_at,
+                    },
+                }
             ],
         )
 
@@ -177,10 +182,9 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
     result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
 
     if result:
-        collection = VECTOR_DB_CLIENT.get_or_create_collection(
-            name=f"user-memory-{user.id}"
+        VECTOR_DB_CLIENT.delete(
+            collection_name=f"user-memory-{user.id}", ids=[memory_id]
         )
-        collection.delete(ids=[memory_id])
         return True
 
     return False

+ 4 - 0
backend/open_webui/config.py

@@ -910,6 +910,10 @@ else:
 CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
 
+# Milvus
+
+MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
+
 ####################################
 # RAG
 ####################################

+ 2 - 0
backend/requirements.txt

@@ -40,6 +40,8 @@ langchain-chroma==0.1.2
 
 fake-useragent==1.5.1
 chromadb==0.5.5
+pymilvus==2.4.6
+
 sentence-transformers==3.0.1
 pypdf==4.3.1
 docx2txt==0.8

+ 1 - 0
pyproject.toml

@@ -47,6 +47,7 @@ dependencies = [
 
     "fake-useragent==1.5.1",
     "chromadb==0.5.5",
+    "pymilvus==2.4.6",
     "sentence-transformers==3.0.1",
     "pypdf==4.3.1",
     "docx2txt==0.8",