ソースを参照

added elasticsearch support

ofek 2 ヶ月 前
コミット
737dfd2763

+ 9 - 0
backend/open_webui/config.py

@@ -1541,6 +1541,15 @@ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
 OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
 OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
 
+# ElasticSearch
+ELASTICSEARCH_URL = os.environ.get("ELASTICSEARCH_URL", "https://localhost:9200")
+ELASTICSEARCH_CA_CERTS = os.environ.get("ELASTICSEARCH_CA_CERTS", None)
+ELASTICSEARCH_API_KEY = os.environ.get("ELASTICSEARCH_API_KEY", None)
+ELASTICSEARCH_USERNAME = os.environ.get("ELASTICSEARCH_USERNAME", None)
+ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD", None)
+ELASTICSEARCH_CLOUD_ID = os.environ.get("ELASTICSEARCH_CLOUD_ID", None)
+SSL_ASSERT_FINGERPRINT = os.environ.get("SSL_ASSERT_FINGERPRINT", None)
+
 # Pgvector
 PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
 if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):

+ 4 - 0
backend/open_webui/retrieval/vector/connector.py

@@ -16,6 +16,10 @@ elif VECTOR_DB == "pgvector":
     from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
 
     VECTOR_DB_CLIENT = PgvectorClient()
+elif VECTOR_DB == "elasticsearch":
+    from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient 
+
+    VECTOR_DB_CLIENT = ElasticsearchClient()
 else:
     from open_webui.retrieval.vector.dbs.chroma import ChromaClient
 

+ 283 - 0
backend/open_webui/retrieval/vector/dbs/elasticsearch.py

@@ -0,0 +1,283 @@
+from elasticsearch import Elasticsearch, BadRequestError
+from typing import Optional
+import ssl
+from elasticsearch.helpers import bulk,scan
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import (
+    ELASTICSEARCH_URL,
+    ELASTICSEARCH_CA_CERTS, 
+    ELASTICSEARCH_API_KEY,
+    ELASTICSEARCH_USERNAME,
+    ELASTICSEARCH_PASSWORD, 
+    ELASTICSEARCH_CLOUD_ID,
+    SSL_ASSERT_FINGERPRINT
+)
+
+
+
+
+class ElasticsearchClient:
+    """
+    Important:
+    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating 
+    an index for each file but store it as a text field, while seperating to different index 
+    baesd on the embedding length.
+    """
+    def __init__(self):
+        self.index_prefix = "open_webui_collections"
+        self.client = Elasticsearch(
+            hosts=[ELASTICSEARCH_URL],
+            ca_certs=ELASTICSEARCH_CA_CERTS,
+            api_key=ELASTICSEARCH_API_KEY,
+            cloud_id=ELASTICSEARCH_CLOUD_ID,
+            basic_auth=(ELASTICSEARCH_USERNAME,ELASTICSEARCH_PASSWORD) if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD else None,
+            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT
+            
+        )
+    #Status: works
+    def _get_index_name(self,dimension:int)->str:
+        return f"{self.index_prefix}_d{str(dimension)}"
+    
+    #Status: works
+    def _scan_result_to_get_result(self, result) -> GetResult:
+        if not result:
+            return None
+        ids = []
+        documents = []
+        metadatas = []
+
+        for hit in result:
+            ids.append(hit["_id"])
+            documents.append(hit["_source"].get("text"))
+            metadatas.append(hit["_source"].get("metadata"))
+
+        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
+
+    #Status: works
+    def _result_to_get_result(self, result) -> GetResult:
+        if not result["hits"]["hits"]:
+            return None
+        ids = []
+        documents = []
+        metadatas = []
+
+        for hit in result["hits"]["hits"]:
+            ids.append(hit["_id"])
+            documents.append(hit["_source"].get("text"))
+            metadatas.append(hit["_source"].get("metadata"))
+
+        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
+
+    #Status: works
+    def _result_to_search_result(self, result) -> SearchResult:
+        ids = []
+        distances = []
+        documents = []
+        metadatas = []
+
+        for hit in result["hits"]["hits"]:
+            ids.append(hit["_id"])
+            distances.append(hit["_score"])
+            documents.append(hit["_source"].get("text"))
+            metadatas.append(hit["_source"].get("metadata"))
+
+        return SearchResult(
+            ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
+        )
+    #Status: works
+    def _create_index(self, dimension: int):
+        body = {
+            "mappings": {
+                "properties": {
+                    "collection": {"type": "keyword"},
+                    "id": {"type": "keyword"},
+                    "vector": {
+                        "type": "dense_vector",
+                        "dims": dimension,  # Adjust based on your vector dimensions
+                        "index": True,
+                        "similarity": "cosine",
+                    },
+                    "text": {"type": "text"},
+                    "metadata": {"type": "object"},
+                }
+            }
+        }
+        self.client.indices.create(index=self._get_index_name(dimension), body=body)
+    #Status: works
+
+    def _create_batches(self, items: list[VectorItem], batch_size=100):
+        for i in range(0, len(items), batch_size):
+            yield items[i : min(i + batch_size,len(items))]
+
+    #Status: works
+    def has_collection(self,collection_name) -> bool:
+        query_body = {"query": {"bool": {"filter": []}}}
+        query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
+
+        try:
+            result = self.client.count(
+                index=f"{self.index_prefix}*",
+                body=query_body
+            )
+            
+            return result.body["count"]>0
+        except Exception as e:
+            return None
+        
+
+        
+    #@TODO: Make this delete a collection and not an index
+    def delete_colleciton(self, collection_name: str):
+        # TODO: fix this to include the dimension or a * prefix
+        # delete_collection here means delete a bunch of documents for an index.
+        # We are simply adapting to the norms of the other DBs.
+        self.client.indices.delete(index=self._get_collection_name(collection_name))
+    #Status: works
+    def search(
+        self, collection_name: str, vectors: list[list[float]], limit: int
+    ) -> Optional[SearchResult]:
+        query = {
+            "size": limit,
+            "_source": [
+                "text",
+                "metadata"
+            ],
+            "query": {
+                "script_score": {
+                    "query": {
+                        "bool": {
+                            "filter": [
+                                {
+                                    "term": {
+                                        "collection": collection_name
+                                    }
+                                }
+                            ]
+                        }
+                    },
+                    "script": {
+                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
+                        "params": {
+                            "vector": vectors[0]
+                        }, # Assuming single query vector
+                    },
+                }
+            },
+        }
+
+        result = self.client.search(
+            index=self._get_index_name(len(vectors[0])), body=query
+        )
+
+        return self._result_to_search_result(result)
+    #Status: only tested halfwat
+    def query(
+        self, collection_name: str, filter: dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        if not self.has_collection(collection_name):
+            return None
+
+        query_body = {
+            "query": {"bool": {"filter": []}},
+            "_source": ["text", "metadata"],
+        }
+
+        for field, value in filter.items():
+            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
+        query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
+        size = limit if limit else 10
+
+        try:
+            result = self.client.search(
+                index=f"{self.index_prefix}*",
+                body=query_body,
+                size=size,
+            )
+            
+            return self._result_to_get_result(result)
+
+        except Exception as e:
+            return None
+    #Status: works
+    def _has_index(self,dimension:int):
+        return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
+
+
+    def get_or_create_index(self, dimension: int):
+        if not self._has_index(dimension=dimension):
+            self._create_index(dimension=dimension)
+    #Status: works
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        # Get all the items in the collection.
+        query = {
+                    "query": {
+                        "bool": {
+                            "filter": [
+                                {
+                                    "term": {
+                                        "collection": collection_name
+                                    }
+                                }
+                            ]
+                        }
+                    }, "_source": ["text", "metadata"]}
+        results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
+        
+        return self._scan_result_to_get_result(results)
+
+    #Status: works
+    def insert(self, collection_name: str, items: list[VectorItem]):
+        if not self._has_index(dimension=len(items[0]["vector"])):
+            self._create_index(dimension=len(items[0]["vector"]))
+
+
+        for batch in self._create_batches(items):
+            actions = [
+                {
+                        "_index":self._get_index_name(dimension=len(items[0]["vector"])),
+                        "_id": item["id"],
+                        "_source": {
+                            "collection": collection_name,
+                            "vector": item["vector"],
+                            "text": item["text"],
+                            "metadata": item["metadata"],
+                        },
+                    }
+                for item in batch
+            ]
+            bulk(self.client,actions)
+    # Status: should work
+    def upsert(self, collection_name: str, items: list[VectorItem]):
+        if not self._has_index(dimension=len(items[0]["vector"])):
+            self._create_index(collection_name, dimension=len(items[0]["vector"]))
+
+        for batch in self._create_batches(items):
+            actions = [
+                {
+                        "_index": self._get_index_name(dimension=len(items[0]["vector"])),
+                        "_id": item["id"],
+                        "_source": {
+                            "vector": item["vector"],
+                            "text": item["text"],
+                            "metadata": item["metadata"],
+                        },
+                    
+                }
+                for item in batch
+            ]
+            self.client.bulk(actions)
+
+    #TODO: This currently deletes by * which is not always supported in ElasticSearch. 
+    # Need to read a bit before changing. Also, need to delete from a specific collection
+    def delete(self, collection_name: str, ids: list[str]):
+        #Assuming ID is unique across collections and indexes
+        actions = [
+            {"delete": {"_index": f"{self.index_prefix}*", "_id": id}}
+            for id in ids
+        ]
+        self.client.bulk(body=actions)
+
+    def reset(self):
+        indices = self.client.indices.get(index=f"{self.index_prefix}*")
+        for index in indices:
+            self.client.indices.delete(index=index)

+ 2 - 0
backend/requirements.txt

@@ -49,6 +49,8 @@ pymilvus==2.5.0
 qdrant-client~=1.12.0
 opensearch-py==2.8.0
 playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
+elasticsearch==8.17.1
+
 
 transformers
 sentence-transformers==3.3.1