Browse Source

Merge pull request #6598 from dtaivpp/main

feat: OpenSearch vector db support
Timothy Jaeryang Baek 5 months ago
parent
commit
6c1d0a8e39

+ 4 - 0
backend/open_webui/apps/retrieval/vector/connector.py

@@ -8,6 +8,10 @@ elif VECTOR_DB == "qdrant":
     from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
     from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
 
 
     VECTOR_DB_CLIENT = QdrantClient()
     VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
    # Imported lazily so the opensearch-py dependency is only required
    # when this backend is actually selected via VECTOR_DB.
    from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient

    VECTOR_DB_CLIENT = OpenSearchClient()
 else:
 else:
     from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
     from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
 
 

+ 152 - 0
backend/open_webui/apps/retrieval/vector/dbs/opensearch.py

@@ -0,0 +1,152 @@
from typing import Optional

from opensearchpy import OpenSearch

# NOTE: this module lives in apps/retrieval/vector/dbs/, so the shared result
# types come from apps.retrieval.vector.main (the "apps.rag" path does not
# match the package layout used by connector.py).
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    OPENSEARCH_URI,
    OPENSEARCH_SSL,
    OPENSEARCH_CERT_VERIFY,
    OPENSEARCH_USERNAME,
    OPENSEARCH_PASSWORD,
)
+
class OpenSearchClient:
    """Vector-database adapter backed by OpenSearch's k-NN plugin.

    Each "collection" maps 1:1 onto an OpenSearch index named
    ``<index_prefix>_<collection_name>``. The public method names
    (has_collection, insert, upsert, search, ...) follow the conventions of
    the other vector-DB adapters in this package.
    """

    def __init__(self):
        self.index_prefix = "open_webui"
        self.client = OpenSearch(
            hosts=[OPENSEARCH_URI],
            use_ssl=OPENSEARCH_SSL,
            verify_certs=OPENSEARCH_CERT_VERIFY,
            http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
        )

    def _index_name(self, index_name: str) -> str:
        # Single place for the "<prefix>_<name>" naming convention.
        return f"{self.index_prefix}_{index_name}"

    def _result_to_get_result(self, result) -> GetResult:
        """Convert a raw OpenSearch search response into a GetResult."""
        ids = []
        documents = []
        metadatas = []

        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return GetResult(ids=ids, documents=documents, metadatas=metadatas)

    def _result_to_search_result(self, result) -> SearchResult:
        """Convert a raw OpenSearch search response into a SearchResult.

        OpenSearch returns similarity scores (higher is better); they are
        passed through unchanged in the ``distances`` field.
        """
        ids = []
        distances = []
        documents = []
        metadatas = []

        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            distances.append(hit["_score"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return SearchResult(
            ids=ids, distances=distances, documents=documents, metadatas=metadatas
        )

    def _create_index(self, index_name: str, dimension: int):
        """Create a k-NN-enabled index for ``dimension``-sized vectors.

        Fixes vs. the original: ``true`` is not a Python literal (NameError),
        and ``dense_vector`` is an Elasticsearch-only field type — the
        OpenSearch k-NN plugin uses ``knn_vector`` and requires
        ``settings.index.knn`` to be enabled. ``"ip"`` is likewise not a
        valid OpenSearch space type; ``innerproduct`` is.
        """
        body = {
            # k-NN must be enabled per-index for knn_vector fields/queries.
            "settings": {"index": {"knn": True}},
            "mappings": {
                "properties": {
                    "id": {"type": "keyword"},
                    "vector": {
                        "type": "knn_vector",
                        "dimension": dimension,
                        "method": {
                            "name": "hnsw",
                            # Inner product approximates cosine similarity
                            # for normalized embedding vectors.
                            "space_type": "innerproduct",
                            "engine": "faiss",
                            "parameters": {"ef_construction": 128, "m": 16},
                        },
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                }
            },
        }
        self.client.indices.create(index=self._index_name(index_name), body=body)

    def _create_batches(self, items: list[VectorItem], batch_size=100):
        """Yield successive ``batch_size``-sized slices of ``items``."""
        for start in range(0, len(items), batch_size):
            yield items[start : start + batch_size]

    def _bulk_index(self, index_name: str, items: list[VectorItem]):
        """Index ``items`` in batches, creating the index on first use.

        Shared by insert() and upsert(): the bulk "index" action overwrites
        any existing document with the same id, so both operations are the
        same request. The original nested ``_source`` inside the action line
        and omitted ``_index``; the bulk API expects an action line followed
        by a separate document line per item.
        """
        if not self.has_collection(index_name):
            # Dimension is inferred from the first item's vector.
            self._create_index(index_name, dimension=len(items[0]["vector"]))

        target = self._index_name(index_name)
        for batch in self._create_batches(items):
            body = []
            for item in batch:
                body.append({"index": {"_index": target, "_id": item["id"]}})
                body.append(
                    {
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    }
                )
            self.client.bulk(body=body)

    def has_collection(self, index_name: str) -> bool:
        # has_collection here means "has index"; we adapt to the naming
        # used by the other vector-DB adapters.
        return self.client.indices.exists(index=self._index_name(index_name))

    def delete_collection(self, index_name: str):
        # delete_collection here means "delete index".
        self.client.indices.delete(index=self._index_name(index_name))

    def delete_colleciton(self, index_name: str):
        # Backward-compatible alias for the original misspelled method name.
        self.delete_collection(index_name)

    def search(
        self, index_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        """Return the ``limit`` nearest neighbours of the first query vector.

        Uses the native OpenSearch ``knn`` query; the original script_score
        relied on ``cosineSimilarity``, an Elasticsearch-only painless
        function that OpenSearch does not provide.
        """
        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            # Only the first vector is queried, matching the original
            # single-query-vector behavior.
            "query": {"knn": {"vector": {"vector": vectors[0], "k": limit}}},
        }

        result = self.client.search(index=self._index_name(index_name), body=query)
        return self._result_to_search_result(result)

    def get_or_create_index(self, index_name: str, dimension: int):
        # Original called the non-existent self.has_index() (AttributeError).
        if not self.has_collection(index_name):
            self._create_index(index_name, dimension)

    def get(self, index_name: str) -> Optional[GetResult]:
        """Fetch documents from the index.

        NOTE(review): without an explicit "size", OpenSearch returns only
        the default page (10 hits) — confirm callers expect a sample rather
        than the full collection.
        """
        query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}

        result = self.client.search(index=self._index_name(index_name), body=query)
        return self._result_to_get_result(result)

    def insert(self, index_name: str, items: list[VectorItem]):
        """Insert ``items``, creating the index if needed."""
        self._bulk_index(index_name, items)

    def upsert(self, index_name: str, items: list[VectorItem]):
        """Insert-or-overwrite ``items`` by id (bulk "index" semantics)."""
        self._bulk_index(index_name, items)

    def delete(self, index_name: str, ids: list[str]):
        """Bulk-delete the documents with the given ids."""
        actions = [
            {"delete": {"_index": self._index_name(index_name), "_id": doc_id}}
            for doc_id in ids
        ]
        self.client.bulk(body=actions)

    def reset(self):
        """Drop every index owned by this adapter (prefix wildcard match)."""
        indices = self.client.indices.get(index=f"{self.index_prefix}_*")
        for index in indices:
            self.client.indices.delete(index=index)

+ 7 - 0
backend/open_webui/config.py

@@ -957,6 +957,13 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
 # Qdrant
 # Qdrant
 QDRANT_URI = os.environ.get("QDRANT_URI", None)
 QDRANT_URI = os.environ.get("QDRANT_URI", None)
 
 
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
# Environment variables are always strings, so a set value like "false" would
# be truthy — parse the boolean flags explicitly instead of passing the raw
# value through.
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"
OPENSEARCH_CERT_VERIFY = (
    os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true"
)
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
+
 ####################################
 ####################################
 # Information Retrieval (RAG)
 # Information Retrieval (RAG)
 ####################################
 ####################################

+ 1 - 0
backend/requirements.txt

@@ -43,6 +43,7 @@ fake-useragent==1.5.1
 chromadb==0.5.15
 chromadb==0.5.15
 pymilvus==2.4.9
 pymilvus==2.4.9
 qdrant-client~=1.12.0
 qdrant-client~=1.12.0
+opensearch-py==2.7.1
 
 
 sentence-transformers==3.2.0
 sentence-transformers==3.2.0
 colbert-ai==0.2.21
 colbert-ai==0.2.21

+ 1 - 0
pyproject.toml

@@ -49,6 +49,7 @@ dependencies = [
     "fake-useragent==1.5.1",
     "fake-useragent==1.5.1",
     "chromadb==0.5.9",
     "chromadb==0.5.9",
     "pymilvus==2.4.7",
     "pymilvus==2.4.7",
+    "opensearch-py==2.7.1",
 
 
     "sentence-transformers==3.2.0",
     "sentence-transformers==3.2.0",
     "colbert-ai==0.2.21",
     "colbert-ai==0.2.21",