浏览代码

fixed es bugs

ofek 1 月之前
父节点
当前提交
a8f205213c
共有 2 个文件被更改,包括 135 次插入93 次删除
  1. 1 1
      backend/open_webui/config.py
  2. 134 92
      backend/open_webui/retrieval/vector/dbs/elasticsearch.py

+ 1 - 1
backend/open_webui/config.py

@@ -1553,7 +1553,7 @@ ELASTICSEARCH_USERNAME = os.environ.get("ELASTICSEARCH_USERNAME", None)
 ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD", None)
 ELASTICSEARCH_CLOUD_ID = os.environ.get("ELASTICSEARCH_CLOUD_ID", None)
 SSL_ASSERT_FINGERPRINT = os.environ.get("SSL_ASSERT_FINGERPRINT", None)
-
+ELASTICSEARCH_INDEX_PREFIX = os.environ.get("ELASTICSEARCH_INDEX_PREFIX", "open_webui_collections")
 # Pgvector
 PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
 if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):

+ 134 - 92
backend/open_webui/retrieval/vector/dbs/elasticsearch.py

@@ -1,47 +1,46 @@
 from elasticsearch import Elasticsearch, BadRequestError
 from typing import Optional
 import ssl
-from elasticsearch.helpers import bulk, scan
+from elasticsearch.helpers import bulk,scan
 from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
     ELASTICSEARCH_URL,
-    ELASTICSEARCH_CA_CERTS,
+    ELASTICSEARCH_CA_CERTS, 
     ELASTICSEARCH_API_KEY,
     ELASTICSEARCH_USERNAME,
-    ELASTICSEARCH_PASSWORD,
+    ELASTICSEARCH_PASSWORD, 
     ELASTICSEARCH_CLOUD_ID,
+    ELASTICSEARCH_INDEX_PREFIX,
     SSL_ASSERT_FINGERPRINT,
+    
 )
 
 
+
+
 class ElasticsearchClient:
     """
     Important:
-    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
-    an index for each file but store it as a text field, while seperating to different index
+    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating 
+    an index for each file but store it as a text field, while seperating to different index 
     baesd on the embedding length.
     """
-
     def __init__(self):
-        self.index_prefix = "open_webui_collections"
+        self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
         self.client = Elasticsearch(
             hosts=[ELASTICSEARCH_URL],
             ca_certs=ELASTICSEARCH_CA_CERTS,
             api_key=ELASTICSEARCH_API_KEY,
             cloud_id=ELASTICSEARCH_CLOUD_ID,
-            basic_auth=(
-                (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
-                if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
-                else None
-            ),
-            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
+            basic_auth=(ELASTICSEARCH_USERNAME,ELASTICSEARCH_PASSWORD) if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD else None,
+            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT
+            
         )
-
-    # Status: works
-    def _get_index_name(self, dimension: int) -> str:
+    #Status: works
+    def _get_index_name(self,dimension:int)->str:
         return f"{self.index_prefix}_d{str(dimension)}"
-
-    # Status: works
+    
+    #Status: works
     def _scan_result_to_get_result(self, result) -> GetResult:
         if not result:
             return None
@@ -56,7 +55,7 @@ class ElasticsearchClient:
 
         return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
 
-    # Status: works
+    #Status: works
     def _result_to_get_result(self, result) -> GetResult:
         if not result["hits"]["hits"]:
             return None
@@ -71,7 +70,7 @@ class ElasticsearchClient:
 
         return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
 
-    # Status: works
+    #Status: works
     def _result_to_search_result(self, result) -> SearchResult:
         ids = []
         distances = []
@@ -85,16 +84,22 @@ class ElasticsearchClient:
             metadatas.append(hit["_source"].get("metadata"))
 
         return SearchResult(
-            ids=[ids],
-            distances=[distances],
-            documents=[documents],
-            metadatas=[metadatas],
+            ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
         )
-
-    # Status: works
+    #Status: works
     def _create_index(self, dimension: int):
         body = {
             "mappings": {
+                "dynamic_templates": [
+                    {
+                    "strings": {
+                    "match_mapping_type": "string",
+                    "mapping": {
+                        "type": "keyword"
+                         }
+                        }
+                    }
+                ],
                 "properties": {
                     "collection": {"type": "keyword"},
                     "id": {"type": "keyword"},
@@ -110,51 +115,64 @@ class ElasticsearchClient:
             }
         }
         self.client.indices.create(index=self._get_index_name(dimension), body=body)
-
-    # Status: works
+    #Status: works
 
     def _create_batches(self, items: list[VectorItem], batch_size=100):
         for i in range(0, len(items), batch_size):
-            yield items[i : min(i + batch_size, len(items))]
+            yield items[i : min(i + batch_size,len(items))]
 
-    # Status: works
-    def has_collection(self, collection_name) -> bool:
+    #Status: works
+    def has_collection(self,collection_name) -> bool:
         query_body = {"query": {"bool": {"filter": []}}}
-        query_body["query"]["bool"]["filter"].append(
-            {"term": {"collection": collection_name}}
-        )
+        query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
 
         try:
-            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
-
-            return result.body["count"] > 0
+            result = self.client.count(
+                index=f"{self.index_prefix}*",
+                body=query_body
+            )
+            
+            return result.body["count"]>0
         except Exception as e:
             return None
+        
 
-    # @TODO: Make this delete a collection and not an index
-    def delete_colleciton(self, collection_name: str):
-        # TODO: fix this to include the dimension or a * prefix
-        # delete_collection here means delete a bunch of documents for an index.
-        # We are simply adapting to the norms of the other DBs.
-        self.client.indices.delete(index=self._get_collection_name(collection_name))
-
-    # Status: works
+        
+    def delete_collection(self, collection_name: str):
+        query = {
+            "query": {
+                "term": {"collection": collection_name}
+            }
+        }
+        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
+    #Status: works
     def search(
         self, collection_name: str, vectors: list[list[float]], limit: int
     ) -> Optional[SearchResult]:
         query = {
             "size": limit,
-            "_source": ["text", "metadata"],
+            "_source": [
+                "text",
+                "metadata"
+            ],
             "query": {
                 "script_score": {
                     "query": {
-                        "bool": {"filter": [{"term": {"collection": collection_name}}]}
+                        "bool": {
+                            "filter": [
+                                {
+                                    "term": {
+                                        "collection": collection_name
+                                    }
+                                }
+                            ]
+                        }
                     },
                     "script": {
                         "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                         "params": {
                             "vector": vectors[0]
-                        },  # Assuming single query vector
+                        }, # Assuming single query vector
                     },
                 }
             },
@@ -165,8 +183,7 @@ class ElasticsearchClient:
         )
 
         return self._result_to_search_result(result)
-
-    # Status: only tested halfwat
+    #Status: only tested halfwat
     def query(
         self, collection_name: str, filter: dict, limit: Optional[int] = None
     ) -> Optional[GetResult]:
@@ -180,9 +197,7 @@ class ElasticsearchClient:
 
         for field, value in filter.items():
             query_body["query"]["bool"]["filter"].append({"term": {field: value}})
-        query_body["query"]["bool"]["filter"].append(
-            {"term": {"collection": collection_name}}
-        )
+        query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
         size = limit if limit else 10
 
         try:
@@ -191,82 +206,109 @@ class ElasticsearchClient:
                 body=query_body,
                 size=size,
             )
-
+            
             return self._result_to_get_result(result)
 
         except Exception as e:
             return None
+    #Status: works
+    def _has_index(self,dimension:int):
+        return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
 
-    # Status: works
-    def _has_index(self, dimension: int):
-        return self.client.indices.exists(
-            index=self._get_index_name(dimension=dimension)
-        )
 
     def get_or_create_index(self, dimension: int):
         if not self._has_index(dimension=dimension):
             self._create_index(dimension=dimension)
-
-    # Status: works
+    #Status: works
     def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.
         query = {
-            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
-            "_source": ["text", "metadata"],
-        }
+                    "query": {
+                        "bool": {
+                            "filter": [
+                                {
+                                    "term": {
+                                        "collection": collection_name
+                                    }
+                                }
+                            ]
+                        }
+                    }, "_source": ["text", "metadata"]}
         results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
-
+        
         return self._scan_result_to_get_result(results)
 
-    # Status: works
+    #Status: works
     def insert(self, collection_name: str, items: list[VectorItem]):
         if not self._has_index(dimension=len(items[0]["vector"])):
             self._create_index(dimension=len(items[0]["vector"]))
 
+
         for batch in self._create_batches(items):
             actions = [
-                {
-                    "_index": self._get_index_name(dimension=len(items[0]["vector"])),
-                    "_id": item["id"],
-                    "_source": {
-                        "collection": collection_name,
-                        "vector": item["vector"],
-                        "text": item["text"],
-                        "metadata": item["metadata"],
-                    },
-                }
+                    {
+                        "_index":self._get_index_name(dimension=len(items[0]["vector"])),
+                        "_id": item["id"],
+                        "_source": {
+                            "collection": collection_name,
+                            "vector": item["vector"],
+                            "text": item["text"],
+                            "metadata": item["metadata"],
+                        },
+                    }
                 for item in batch
             ]
-            bulk(self.client, actions)
+            bulk(self.client,actions)
 
-    # Status: should work
+    # Upsert documents using the update API with doc_as_upsert=True.
     def upsert(self, collection_name: str, items: list[VectorItem]):
         if not self._has_index(dimension=len(items[0]["vector"])):
-            self._create_index(collection_name, dimension=len(items[0]["vector"]))
-
+            self._create_index(dimension=len(items[0]["vector"]))
         for batch in self._create_batches(items):
             actions = [
                 {
-                    "_index": self._get_index_name(dimension=len(items[0]["vector"])),
+                    "_op_type": "update",
+                    "_index": self._get_index_name(dimension=len(item["vector"])),
                     "_id": item["id"],
-                    "_source": {
+                    "doc": {
+                        "collection": collection_name,
                         "vector": item["vector"],
                         "text": item["text"],
                         "metadata": item["metadata"],
                     },
+                    "doc_as_upsert": True,
                 }
                 for item in batch
             ]
-            self.client.bulk(actions)
-
-    # TODO: This currently deletes by * which is not always supported in ElasticSearch.
-    # Need to read a bit before changing. Also, need to delete from a specific collection
-    def delete(self, collection_name: str, ids: list[str]):
-        # Assuming ID is unique across collections and indexes
-        actions = [
-            {"delete": {"_index": f"{self.index_prefix}*", "_id": id}} for id in ids
-        ]
-        self.client.bulk(body=actions)
+            bulk(self.client,actions)
+
+
+    # Delete specific documents from a collection by filtering on both collection and document IDs.
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[list[str]] = None,
+        filter: Optional[dict] = None,
+    ):
+
+        query = {
+            "query": {
+                "bool": {
+                    "filter": [
+                        {"term": {"collection": collection_name}}
+                    ]
+                }
+            }
+        }
+                #logic based on chromaDB
+        if ids:
+            query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
+        elif filter:
+            for field, value in filter.items():
+                query["query"]["bool"]["filter"].append({"term": {f"metadata.{field}": value}})
+
+
+        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
 
     def reset(self):
         indices = self.client.indices.get(index=f"{self.index_prefix}*")