Ver código fonte

fix: opensearch vector db query structures, result mapping, filters, bulk query actions, knn_vector usage

Katharina 2 meses atrás
pai
commit
6cb0c0339a
1 arquivo alterado com 120 adições e 64 exclusões
  1. 120 64
      backend/open_webui/retrieval/vector/dbs/opensearch.py

+ 120 - 64
backend/open_webui/retrieval/vector/dbs/opensearch.py

@@ -1,4 +1,5 @@
 from opensearchpy import OpenSearch
+from opensearchpy.helpers import bulk
 from typing import Optional
 
 from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
@@ -20,8 +21,14 @@ class OpenSearchClient:
             verify_certs=OPENSEARCH_CERT_VERIFY,
             http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
         )
+    
+    def _get_index_name(self, collection_name: str) -> str:
+        return f"{self.index_prefix}_{collection_name}"
 
     def _result_to_get_result(self, result) -> GetResult:
+        if not result["hits"]["hits"]:
+            return None
+        
         ids = []
         documents = []
         metadatas = []
@@ -31,9 +38,12 @@ class OpenSearchClient:
             documents.append(hit["_source"].get("text"))
             metadatas.append(hit["_source"].get("metadata"))
 
-        return GetResult(ids=ids, documents=documents, metadatas=metadatas)
+        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
 
     def _result_to_search_result(self, result) -> SearchResult:
+        if not result["hits"]["hits"]:
+            return None
+        
         ids = []
         distances = []
         documents = []
@@ -46,25 +56,32 @@ class OpenSearchClient:
             metadatas.append(hit["_source"].get("metadata"))
 
         return SearchResult(
-            ids=ids, distances=distances, documents=documents, metadatas=metadatas
+            ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
         )
 
     def _create_index(self, collection_name: str, dimension: int):
         body = {
+            "settings": {
+                "index": {
+                    "knn": True
+                }
+            },
             "mappings": {
                 "properties": {
                     "id": {"type": "keyword"},
                     "vector": {
-                        "type": "dense_vector",
-                        "dims": dimension,  # Adjust based on your vector dimensions
-                        "index": true,
+                        "type": "knn_vector",
+                        "dimension": dimension,  # Adjust based on your vector dimensions
                         "method": {
                             "name": "hnsw",
-                            "space_type": "ip",  # Use inner product to approximate cosine similarity
+                            "space_type": "innerproduct",  # Use inner product to approximate cosine similarity
                             "engine": "faiss",
-                            "ef_construction": 128,
-                            "m": 16,
+                            "parameters": {
+                                "ef_construction": 128,
+                                "m": 16,
+                            }
                         },
                     },
                     "text": {"type": "text"},
@@ -73,7 +90,7 @@ class OpenSearchClient:
             }
         }
         self.client.indices.create(
-            index=f"{self.index_prefix}_{collection_name}", body=body
+            index=self._get_index_name(collection_name), body=body
         )
 
     def _create_batches(self, items: list[VectorItem], batch_size=100):
@@ -84,38 +101,49 @@ class OpenSearchClient:
         # has_collection here means has index.
         # We are simply adapting to the norms of the other DBs.
         return self.client.indices.exists(
-            index=f"{self.index_prefix}_{collection_name}"
+            index=self._get_index_name(collection_name)
         )
 
-    def delete_colleciton(self, collection_name: str):
+    def delete_collection(self, collection_name: str):
         # delete_collection here means delete index.
         # We are simply adapting to the norms of the other DBs.
-        self.client.indices.delete(index=f"{self.index_prefix}_{collection_name}")
+        self.client.indices.delete(index=self._get_index_name(collection_name))
 
     def search(
-        self, collection_name: str, vectors: list[list[float]], limit: int
+        self, collection_name: str, vectors: list[list[float | int]], limit: int
     ) -> Optional[SearchResult]:
-        query = {
-            "size": limit,
-            "_source": ["text", "metadata"],
-            "query": {
-                "script_score": {
-                    "query": {"match_all": {}},
-                    "script": {
-                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
-                        "params": {
-                            "vector": vectors[0]
-                        },  # Assuming single query vector
-                    },
-                }
-            },
-        }
-
-        result = self.client.search(
-            index=f"{self.index_prefix}_{collection_name}", body=query
-        )
+        try:
+            if not self.has_collection(collection_name):
+                return None
+            
+            query = {
+                "size": limit,
+                "_source": ["text", "metadata"],
+                "query": {
+                    "script_score": {
+                        "query": {
+                            "match_all": {}
+                        },
+                        "script": {
+                            "source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
+                            "params": {
+                                "field": "vector",
+                                "query_value": vectors[0]
+                            },  # Assuming single query vector
+                        },
+                    }
+                },
+            }
+            
+            result = self.client.search(
+                index=self._get_index_name(collection_name),
+                body=query
+            )
 
-        return self._result_to_search_result(result)
+            return self._result_to_search_result(result)
+        
+        except Exception:
+            return None
 
     def query(
         self, collection_name: str, filter: dict, limit: Optional[int] = None
@@ -124,18 +152,26 @@ class OpenSearchClient:
             return None
 
         query_body = {
-            "query": {"bool": {"filter": []}},
+            "query": {
+                "bool": {
+                    "filter": []
+                }
+            },
             "_source": ["text", "metadata"],
         }
 
         for field, value in filter.items():
-            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
+            query_body["query"]["bool"]["filter"].append({
+                "match": {
+                    "metadata." + str(field): value
+                }
+            })
 
         size = limit if limit else 10
 
         try:
             result = self.client.search(
-                index=f"{self.index_prefix}_{collection_name}",
+                index=self._get_index_name(collection_name),
                 body=query_body,
                 size=size,
             )
@@ -146,14 +182,14 @@ class OpenSearchClient:
             return None
 
     def _create_index_if_not_exists(self, collection_name: str, dimension: int):
-        if not self.has_index(collection_name):
+        if not self.has_collection(collection_name):
             self._create_index(collection_name, dimension)
 
     def get(self, collection_name: str) -> Optional[GetResult]:
         query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
 
         result = self.client.search(
-            index=f"{self.index_prefix}_{collection_name}", body=query
+            index=self._get_index_name(collection_name), body=query
         )
         return self._result_to_get_result(result)
 
@@ -165,18 +201,18 @@ class OpenSearchClient:
         for batch in self._create_batches(items):
             actions = [
                 {
-                    "index": {
-                        "_id": item["id"],
-                        "_source": {
-                            "vector": item["vector"],
-                            "text": item["text"],
-                            "metadata": item["metadata"],
-                        },
-                    }
+                    "_op_type": "index", 
+                    "_index": self._get_index_name(collection_name),
+                    "_id": item["id"], 
+                    "_source": {
+                        "vector": item["vector"],
+                        "text": item["text"],
+                        "metadata": item["metadata"],
+                    },
                 }
                 for item in batch
             ]
-            self.client.bulk(actions)
+            bulk(self.client, actions)
 
     def upsert(self, collection_name: str, items: list[VectorItem]):
         self._create_index_if_not_exists(
@@ -186,27 +222,47 @@ class OpenSearchClient:
         for batch in self._create_batches(items):
             actions = [
                 {
-                    "index": {
-                        "_id": item["id"],
-                        "_index": f"{self.index_prefix}_{collection_name}",
-                        "_source": {
-                            "vector": item["vector"],
-                            "text": item["text"],
-                            "metadata": item["metadata"],
-                        },
-                    }
+                    "_op_type": "update", 
+                    "_index": self._get_index_name(collection_name),
+                    "_id": item["id"], 
+                    "doc": {
+                        "vector": item["vector"],
+                        "text": item["text"],
+                        "metadata": item["metadata"],
+                    },
+                    "doc_as_upsert": True,
                 }
                 for item in batch
             ]
-            self.client.bulk(actions)
-
-    def delete(self, collection_name: str, ids: list[str]):
-        actions = [
-            {"delete": {"_index": f"{self.index_prefix}_{collection_name}", "_id": id}}
-            for id in ids
-        ]
-        self.client.bulk(body=actions)
+            bulk(self.client, actions)
 
+    def delete(self, collection_name: str, ids: Optional[list[str]] = None, filter: Optional[dict] = None):
+        if ids:
+            actions = [
+                {
+                    "_op_type": "delete",
+                    "_index": self._get_index_name(collection_name),
+                    "_id": id,
+                }
+                for id in ids
+            ]
+            bulk(self.client, actions)
+        elif filter:
+            query_body = {
+                "query": {
+                    "bool": {
+                        "filter": []
+                    }
+                },
+            }
+            for field, value in filter.items():
+                query_body["query"]["bool"]["filter"].append({
+                    "match": {
+                        "metadata." + str(field): value
+                    }
+                })
+            self.client.delete_by_query(index=self._get_index_name(collection_name), body=query_body)
+                
     def reset(self):
         indices = self.client.indices.get(index=f"{self.index_prefix}_*")
         for index in indices: