Browse Source

feat: milvus support

Timothy J. Baek 7 months ago
parent
commit
4775fe43d8

+ 0 - 1
backend/open_webui/apps/rag/main.py

@@ -1010,7 +1010,6 @@ def store_docs_in_vector_db(
             app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
         )
 
-        VECTOR_DB_CLIENT.create_collection(collection_name=collection_name)
         VECTOR_DB_CLIENT.insert(
             collection_name=collection_name,
             items=[

+ 2 - 6
backend/open_webui/apps/rag/vector/dbs/chroma.py

@@ -41,10 +41,6 @@ class ChromaClient:
         collections = self.client.list_collections()
         return [collection.name for collection in collections]
 
-    def create_collection(self, collection_name: str):
-        # Create a new collection based on the collection name.
-        return self.client.create_collection(name=collection_name)
-
     def delete_collection(self, collection_name: str):
         # Delete the collection based on the collection name.
         return self.client.delete_collection(name=collection_name)
@@ -76,7 +72,7 @@ class ChromaClient:
         return None
 
     def insert(self, collection_name: str, items: list[VectorItem]):
-        # Insert the items into the collection.
+        # Insert the items into the collection; if the collection does not exist, it will be created.
         collection = self.client.get_or_create_collection(name=collection_name)
 
         ids = [item["id"] for item in items]
@@ -94,7 +90,7 @@ class ChromaClient:
             collection.add(*batch)
 
     def upsert(self, collection_name: str, items: list[VectorItem]):
-        # Update the items in the collection, if the items are not present, insert them.
+        # Update the items in the collection; items that are not present are inserted. If the collection does not exist, it will be created.
         collection = self.client.get_or_create_collection(name=collection_name)
 
         ids = [item["id"] for item in items]

+ 149 - 13
backend/open_webui/apps/rag/vector/dbs/milvus.py

@@ -1,39 +1,175 @@
-from pymilvus import MilvusClient as Milvus
+from pymilvus import MilvusClient as Client
+from pymilvus import FieldSchema, DataType
+import json
 
 from typing import Optional
 
 from open_webui.apps.rag.vector.main import VectorItem, QueryResult
+from open_webui.config import (
+    MILVUS_URI,
+)
 
 
 class MilvusClient:
     def __init__(self):
-        self.client = Milvus()
+        # Namespace joined with "_" in front of every collection name used by this client.
+        self.collection_prefix = "open_webui"
+        # Client for the Milvus instance located at the configured MILVUS_URI.
+        self.client = Client(uri=MILVUS_URI)
 
-    def list_collections(self) -> list[str]:
-        pass
+    def _result_to_query_result(self, result) -> QueryResult:
+        # Reshape search output into parallel per-query lists. `result` is
+        # iterated as a sequence of matches (one per query vector), and each
+        # match as a sequence of dict-like hits carrying "id", "distance" and
+        # an "entity" holding the stored "data" and "metadata" fields.
+        # Missing keys become None instead of raising.
+        ids = []
+        distances = []
+        documents = []
+        metadatas = []
+
+        for match in result:
+            _ids = []
+            _distances = []
+            _documents = []
+            _metadatas = []
+
+            for item in match:
+                # Look the entity up once; both payload fields live inside it.
+                entity = item.get("entity", {})
+                _ids.append(item.get("id"))
+                _distances.append(item.get("distance"))
+                _documents.append(entity.get("data", {}).get("text"))
+                _metadatas.append(entity.get("metadata"))
+
+            ids.append(_ids)
+            distances.append(_distances)
+            documents.append(_documents)
+            metadatas.append(_metadatas)
+
+        return {
+            "ids": ids,
+            "distances": distances,
+            "documents": documents,
+            "metadatas": metadatas,
+        }
+
+    def _create_collection(self, collection_name: str, dimension: int):
+        # Create the prefixed collection for `collection_name` with an explicit
+        # schema and a vector index. `dimension` is the size of the vector field.
+        schema = self.client.create_schema(
+            auto_id=False,
+            enable_dynamic_field=True,
+        )
+        # Primary key: string ids supplied with the data (auto_id is disabled above).
+        schema.add_field(
+            field_name="id",
+            datatype=DataType.VARCHAR,
+            is_primary=True,
+            max_length=65535,
+        )
+        schema.add_field(
+            field_name="vector",
+            datatype=DataType.FLOAT_VECTOR,
+            dim=dimension,
+            description="vector",
+        )
+        # Two JSON payload fields stored next to each vector.
+        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
+        schema.add_field(
+            field_name="metadata", datatype=DataType.JSON, description="metadata"
+        )
+
+        # HNSW index over the vector field, compared with the COSINE metric.
+        index_params = self.client.prepare_index_params()
+        index_params.add_index(
+            field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
+        )
 
-    def create_collection(self, collection_name: str):
-        pass
+        self.client.create_collection(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            schema=schema,
+            index_params=index_params,
+        )
+
+    def list_collections(self) -> list[str]:
+        # List the collections that belong to this client, with the
+        # "<prefix>_" namespace removed so each returned name can be passed
+        # straight back to the other methods (which re-add the same prefix).
+        namespace = f"{self.collection_prefix}_"
+        return [
+            collection[len(namespace) :]
+            for collection in self.client.list_collections()
+            if collection.startswith(namespace)
+        ]
 
     def delete_collection(self, collection_name: str):
-        pass
+        # Drop the prefixed collection that backs `collection_name`.
+        target = f"{self.collection_prefix}_{collection_name}"
+        return self.client.drop_collection(collection_name=target)
 
     def search(
         self, collection_name: str, vectors: list[list[float | int]], limit: int
     ) -> Optional[QueryResult]:
-        pass
+        # Look up the nearest stored items for each query vector, keeping at
+        # most 'limit' hits per vector, and return them as a QueryResult.
+        target = f"{self.collection_prefix}_{collection_name}"
+        hits = self.client.search(
+            collection_name=target,
+            data=vectors,
+            limit=limit,
+            output_fields=["data", "metadata"],
+        )
+        return self._result_to_query_result(hits)
 
     def get(self, collection_name: str) -> Optional[QueryResult]:
-        pass
+        # Get all the items in the collection.
+        #
+        # A scalar query needs a filter expression, so an always-true filter on
+        # the string primary key is used to select every row.
+        rows = self.client.query(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            filter='id != ""',
+            output_fields=["id", "data", "metadata"],
+        )
+        # query() yields one flat list of row dicts, whereas the converter
+        # expects search-style output (a list of hit lists whose payload sits
+        # under "entity"). Wrap the rows accordingly; a plain query has no
+        # distance, so it is reported as None.
+        hits = [{"id": row.get("id"), "distance": None, "entity": row} for row in rows]
+        return self._result_to_query_result([hits])
 
     def insert(self, collection_name: str, items: list[VectorItem]):
-        pass
+        # Insert the items into the collection; if the collection does not
+        # exist, it is created first with the dimension of the first item's vector.
+        # Returns whatever the underlying client's insert() returns, or None
+        # when `items` is empty.
+        if not items:
+            # Nothing to write, and no vector to infer a dimension from.
+            return None
+
+        target = f"{self.collection_prefix}_{collection_name}"
+        if not self.client.has_collection(collection_name=target):
+            self._create_collection(
+                collection_name=collection_name, dimension=len(items[0]["vector"])
+            )
+
+        return self.client.insert(
+            collection_name=target,
+            data=[
+                {
+                    "id": item["id"],
+                    "vector": item["vector"],
+                    "data": {"text": item["text"]},
+                    "metadata": item["metadata"],
+                }
+                for item in items
+            ],
+        )
 
     def upsert(self, collection_name: str, items: list[VectorItem]):
-        pass
+        # Update the items in the collection; items that are not present are
+        # inserted. If the collection does not exist, it is created first with
+        # the dimension of the first item's vector.
+        # Returns whatever the underlying client's upsert() returns, or None
+        # when `items` is empty.
+        if not items:
+            # Nothing to write, and no vector to infer a dimension from.
+            return None
+
+        target = f"{self.collection_prefix}_{collection_name}"
+        if not self.client.has_collection(collection_name=target):
+            self._create_collection(
+                collection_name=collection_name, dimension=len(items[0]["vector"])
+            )
+
+        return self.client.upsert(
+            collection_name=target,
+            data=[
+                {
+                    "id": item["id"],
+                    "vector": item["vector"],
+                    "data": {"text": item["text"]},
+                    "metadata": item["metadata"],
+                }
+                for item in items
+            ],
+        )
 
     def delete(self, collection_name: str, ids: list[str]):
-        pass
+        # Remove the entries with the given ids from the prefixed collection.
+        target = f"{self.collection_prefix}_{collection_name}"
+        return self.client.delete(collection_name=target, ids=ids)
 
     def reset(self):
-        pass
+        # Resets the database: drops every collection in this client's
+        # "<prefix>_" namespace, together with all of its entries.
+        #
+        # The match includes the trailing underscore so that collections which
+        # merely share the leading characters (e.g. "open_webui2_x") and were
+        # not created through this client are left untouched.
+        namespace = f"{self.collection_prefix}_"
+        for collection_name in self.client.list_collections():
+            if collection_name.startswith(namespace):
+                self.client.drop_collection(collection_name=collection_name)

+ 4 - 0
backend/open_webui/config.py

@@ -910,6 +910,10 @@ else:
 CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
 
+# Milvus
+
+MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
+
 ####################################
 # RAG
 ####################################

+ 2 - 0
backend/requirements.txt

@@ -40,6 +40,8 @@ langchain-chroma==0.1.2
 
 fake-useragent==1.5.1
 chromadb==0.5.5
+pymilvus==2.4.6
+
 sentence-transformers==3.0.1
 pypdf==4.3.1
 docx2txt==0.8

+ 1 - 0
pyproject.toml

@@ -47,6 +47,7 @@ dependencies = [
 
     "fake-useragent==1.5.1",
     "chromadb==0.5.5",
+    "pymilvus==2.4.6",
     "sentence-transformers==3.0.1",
     "pypdf==4.3.1",
     "docx2txt==0.8",