Browse Source

chore: format

Timothy Jaeryang Baek 1 month ago
parent
commit
d4fca9dabf
1 changed files with 81 additions and 102 deletions
  1. 81 102
      backend/open_webui/retrieval/vector/dbs/elasticsearch.py

+ 81 - 102
backend/open_webui/retrieval/vector/dbs/elasticsearch.py

@@ -1,30 +1,28 @@
 from elasticsearch import Elasticsearch, BadRequestError
 from typing import Optional
 import ssl
-from elasticsearch.helpers import bulk,scan
+from elasticsearch.helpers import bulk, scan
 from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
     ELASTICSEARCH_URL,
-    ELASTICSEARCH_CA_CERTS, 
+    ELASTICSEARCH_CA_CERTS,
     ELASTICSEARCH_API_KEY,
     ELASTICSEARCH_USERNAME,
-    ELASTICSEARCH_PASSWORD, 
+    ELASTICSEARCH_PASSWORD,
     ELASTICSEARCH_CLOUD_ID,
     ELASTICSEARCH_INDEX_PREFIX,
     SSL_ASSERT_FINGERPRINT,
-    
 )
 
 
-
-
 class ElasticsearchClient:
     """
     Important:
-    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating 
-    an index for each file but store it as a text field, while seperating to different index 
+    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
+    an index for each file but store it as a text field, while seperating to different index
     baesd on the embedding length.
     """
+
     def __init__(self):
         self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
         self.client = Elasticsearch(
@@ -32,15 +30,19 @@ class ElasticsearchClient:
             ca_certs=ELASTICSEARCH_CA_CERTS,
             api_key=ELASTICSEARCH_API_KEY,
             cloud_id=ELASTICSEARCH_CLOUD_ID,
-            basic_auth=(ELASTICSEARCH_USERNAME,ELASTICSEARCH_PASSWORD) if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD else None,
-            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT
-            
+            basic_auth=(
+                (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
+                if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
+                else None
+            ),
+            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
         )
-    #Status: works
-    def _get_index_name(self,dimension:int)->str:
+
+    # Status: works
+    def _get_index_name(self, dimension: int) -> str:
         return f"{self.index_prefix}_d{str(dimension)}"
-    
-    #Status: works
+
+    # Status: works
     def _scan_result_to_get_result(self, result) -> GetResult:
         if not result:
             return None
@@ -55,7 +57,7 @@ class ElasticsearchClient:
 
         return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
 
-    #Status: works
+    # Status: works
     def _result_to_get_result(self, result) -> GetResult:
         if not result["hits"]["hits"]:
             return None
@@ -70,7 +72,7 @@ class ElasticsearchClient:
 
         return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
 
-    #Status: works
+    # Status: works
     def _result_to_search_result(self, result) -> SearchResult:
         ids = []
         distances = []
@@ -84,19 +86,21 @@ class ElasticsearchClient:
             metadatas.append(hit["_source"].get("metadata"))
 
         return SearchResult(
-            ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
+            ids=[ids],
+            distances=[distances],
+            documents=[documents],
+            metadatas=[metadatas],
         )
-    #Status: works
+
+    # Status: works
     def _create_index(self, dimension: int):
         body = {
             "mappings": {
                 "dynamic_templates": [
                     {
-                    "strings": {
-                    "match_mapping_type": "string",
-                    "mapping": {
-                        "type": "keyword"
-                         }
+                        "strings": {
+                            "match_mapping_type": "string",
+                            "mapping": {"type": "keyword"},
                         }
                     }
                 ],
@@ -111,68 +115,52 @@ class ElasticsearchClient:
                     },
                     "text": {"type": "text"},
                     "metadata": {"type": "object"},
-                }
+                },
             }
         }
         self.client.indices.create(index=self._get_index_name(dimension), body=body)
-    #Status: works
+
+    # Status: works
 
     def _create_batches(self, items: list[VectorItem], batch_size=100):
         for i in range(0, len(items), batch_size):
-            yield items[i : min(i + batch_size,len(items))]
+            yield items[i : min(i + batch_size, len(items))]
 
-    #Status: works
-    def has_collection(self,collection_name) -> bool:
+    # Status: works
+    def has_collection(self, collection_name) -> bool:
         query_body = {"query": {"bool": {"filter": []}}}
-        query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
+        query_body["query"]["bool"]["filter"].append(
+            {"term": {"collection": collection_name}}
+        )
 
         try:
-            result = self.client.count(
-                index=f"{self.index_prefix}*",
-                body=query_body
-            )
-            
-            return result.body["count"]>0
+            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
+
+            return result.body["count"] > 0
         except Exception as e:
             return None
-        
 
-        
     def delete_collection(self, collection_name: str):
-        query = {
-            "query": {
-                "term": {"collection": collection_name}
-            }
-        }
+        query = {"query": {"term": {"collection": collection_name}}}
         self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
-    #Status: works
+
+    # Status: works
     def search(
         self, collection_name: str, vectors: list[list[float]], limit: int
     ) -> Optional[SearchResult]:
         query = {
             "size": limit,
-            "_source": [
-                "text",
-                "metadata"
-            ],
+            "_source": ["text", "metadata"],
             "query": {
                 "script_score": {
                     "query": {
-                        "bool": {
-                            "filter": [
-                                {
-                                    "term": {
-                                        "collection": collection_name
-                                    }
-                                }
-                            ]
-                        }
+                        "bool": {"filter": [{"term": {"collection": collection_name}}]}
                     },
                     "script": {
                         "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                         "params": {
                             "vector": vectors[0]
-                        }, # Assuming single query vector
+                        },  # Assuming single query vector
                     },
                 }
             },
@@ -183,7 +171,8 @@ class ElasticsearchClient:
         )
 
         return self._result_to_search_result(result)
-    #Status: only tested halfwat
+
+    # Status: only tested halfwat
     def query(
         self, collection_name: str, filter: dict, limit: Optional[int] = None
     ) -> Optional[GetResult]:
@@ -197,7 +186,9 @@ class ElasticsearchClient:
 
         for field, value in filter.items():
             query_body["query"]["bool"]["filter"].append({"term": {field: value}})
-        query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
+        query_body["query"]["bool"]["filter"].append(
+            {"term": {"collection": collection_name}}
+        )
         size = limit if limit else 10
 
         try:
@@ -206,59 +197,53 @@ class ElasticsearchClient:
                 body=query_body,
                 size=size,
             )
-            
+
             return self._result_to_get_result(result)
 
         except Exception as e:
             return None
-    #Status: works
-    def _has_index(self,dimension:int):
-        return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
 
+    # Status: works
+    def _has_index(self, dimension: int):
+        return self.client.indices.exists(
+            index=self._get_index_name(dimension=dimension)
+        )
 
     def get_or_create_index(self, dimension: int):
         if not self._has_index(dimension=dimension):
             self._create_index(dimension=dimension)
-    #Status: works
+
+    # Status: works
     def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.
         query = {
-                    "query": {
-                        "bool": {
-                            "filter": [
-                                {
-                                    "term": {
-                                        "collection": collection_name
-                                    }
-                                }
-                            ]
-                        }
-                    }, "_source": ["text", "metadata"]}
+            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
+            "_source": ["text", "metadata"],
+        }
         results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
-        
+
         return self._scan_result_to_get_result(results)
 
-    #Status: works
+    # Status: works
     def insert(self, collection_name: str, items: list[VectorItem]):
         if not self._has_index(dimension=len(items[0]["vector"])):
             self._create_index(dimension=len(items[0]["vector"]))
 
-
         for batch in self._create_batches(items):
             actions = [
-                    {
-                        "_index":self._get_index_name(dimension=len(items[0]["vector"])),
-                        "_id": item["id"],
-                        "_source": {
-                            "collection": collection_name,
-                            "vector": item["vector"],
-                            "text": item["text"],
-                            "metadata": item["metadata"],
-                        },
-                    }
+                {
+                    "_index": self._get_index_name(dimension=len(items[0]["vector"])),
+                    "_id": item["id"],
+                    "_source": {
+                        "collection": collection_name,
+                        "vector": item["vector"],
+                        "text": item["text"],
+                        "metadata": item["metadata"],
+                    },
+                }
                 for item in batch
             ]
-            bulk(self.client,actions)
+            bulk(self.client, actions)
 
     # Upsert documents using the update API with doc_as_upsert=True.
     def upsert(self, collection_name: str, items: list[VectorItem]):
@@ -280,8 +265,7 @@ class ElasticsearchClient:
                 }
                 for item in batch
             ]
-            bulk(self.client,actions)
-
+            bulk(self.client, actions)
 
     # Delete specific documents from a collection by filtering on both collection and document IDs.
     def delete(
@@ -292,21 +276,16 @@ class ElasticsearchClient:
     ):
 
         query = {
-            "query": {
-                "bool": {
-                    "filter": [
-                        {"term": {"collection": collection_name}}
-                    ]
-                }
-            }
+            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
         }
-                #logic based on chromaDB
+        # logic based on chromaDB
         if ids:
             query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
         elif filter:
             for field, value in filter.items():
-                query["query"]["bool"]["filter"].append({"term": {f"metadata.{field}": value}})
-
+                query["query"]["bool"]["filter"].append(
+                    {"term": {f"metadata.{field}": value}}
+                )
 
         self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)