elasticsearch.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
import logging
import ssl
from typing import Optional

from elasticsearch import BadRequestError, Elasticsearch
from elasticsearch.helpers import bulk, scan

from open_webui.config import (
    ELASTICSEARCH_API_KEY,
    ELASTICSEARCH_CA_CERTS,
    ELASTICSEARCH_CLOUD_ID,
    ELASTICSEARCH_INDEX_PREFIX,
    ELASTICSEARCH_PASSWORD,
    ELASTICSEARCH_URL,
    ELASTICSEARCH_USERNAME,
    SSL_ASSERT_FINGERPRINT,
)
from open_webui.retrieval.vector.main import GetResult, SearchResult, VectorItem
  16. class ElasticsearchClient:
  17. """
  18. Important:
  19. in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
  20. an index for each file but store it as a text field, while seperating to different index
  21. baesd on the embedding length.
  22. """
  23. def __init__(self):
  24. self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
  25. self.client = Elasticsearch(
  26. hosts=[ELASTICSEARCH_URL],
  27. ca_certs=ELASTICSEARCH_CA_CERTS,
  28. api_key=ELASTICSEARCH_API_KEY,
  29. cloud_id=ELASTICSEARCH_CLOUD_ID,
  30. basic_auth=(
  31. (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
  32. if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
  33. else None
  34. ),
  35. ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
  36. )
  37. # Status: works
  38. def _get_index_name(self, dimension: int) -> str:
  39. return f"{self.index_prefix}_d{str(dimension)}"
  40. # Status: works
  41. def _scan_result_to_get_result(self, result) -> GetResult:
  42. if not result:
  43. return None
  44. ids = []
  45. documents = []
  46. metadatas = []
  47. for hit in result:
  48. ids.append(hit["_id"])
  49. documents.append(hit["_source"].get("text"))
  50. metadatas.append(hit["_source"].get("metadata"))
  51. return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
  52. # Status: works
  53. def _result_to_get_result(self, result) -> GetResult:
  54. if not result["hits"]["hits"]:
  55. return None
  56. ids = []
  57. documents = []
  58. metadatas = []
  59. for hit in result["hits"]["hits"]:
  60. ids.append(hit["_id"])
  61. documents.append(hit["_source"].get("text"))
  62. metadatas.append(hit["_source"].get("metadata"))
  63. return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
  64. # Status: works
  65. def _result_to_search_result(self, result) -> SearchResult:
  66. ids = []
  67. distances = []
  68. documents = []
  69. metadatas = []
  70. for hit in result["hits"]["hits"]:
  71. ids.append(hit["_id"])
  72. distances.append(hit["_score"])
  73. documents.append(hit["_source"].get("text"))
  74. metadatas.append(hit["_source"].get("metadata"))
  75. return SearchResult(
  76. ids=[ids],
  77. distances=[distances],
  78. documents=[documents],
  79. metadatas=[metadatas],
  80. )
  81. # Status: works
  82. def _create_index(self, dimension: int):
  83. body = {
  84. "mappings": {
  85. "dynamic_templates": [
  86. {
  87. "strings": {
  88. "match_mapping_type": "string",
  89. "mapping": {"type": "keyword"},
  90. }
  91. }
  92. ],
  93. "properties": {
  94. "collection": {"type": "keyword"},
  95. "id": {"type": "keyword"},
  96. "vector": {
  97. "type": "dense_vector",
  98. "dims": dimension, # Adjust based on your vector dimensions
  99. "index": True,
  100. "similarity": "cosine",
  101. },
  102. "text": {"type": "text"},
  103. "metadata": {"type": "object"},
  104. },
  105. }
  106. }
  107. self.client.indices.create(index=self._get_index_name(dimension), body=body)
  108. # Status: works
  109. def _create_batches(self, items: list[VectorItem], batch_size=100):
  110. for i in range(0, len(items), batch_size):
  111. yield items[i : min(i + batch_size, len(items))]
  112. # Status: works
  113. def has_collection(self, collection_name) -> bool:
  114. query_body = {"query": {"bool": {"filter": []}}}
  115. query_body["query"]["bool"]["filter"].append(
  116. {"term": {"collection": collection_name}}
  117. )
  118. try:
  119. result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
  120. return result.body["count"] > 0
  121. except Exception as e:
  122. return None
  123. def delete_collection(self, collection_name: str):
  124. query = {"query": {"term": {"collection": collection_name}}}
  125. self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
  126. # Status: works
  127. def search(
  128. self, collection_name: str, vectors: list[list[float]], limit: int
  129. ) -> Optional[SearchResult]:
  130. query = {
  131. "size": limit,
  132. "_source": ["text", "metadata"],
  133. "query": {
  134. "script_score": {
  135. "query": {
  136. "bool": {"filter": [{"term": {"collection": collection_name}}]}
  137. },
  138. "script": {
  139. "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
  140. "params": {
  141. "vector": vectors[0]
  142. }, # Assuming single query vector
  143. },
  144. }
  145. },
  146. }
  147. result = self.client.search(
  148. index=self._get_index_name(len(vectors[0])), body=query
  149. )
  150. return self._result_to_search_result(result)
  151. # Status: only tested halfwat
  152. def query(
  153. self, collection_name: str, filter: dict, limit: Optional[int] = None
  154. ) -> Optional[GetResult]:
  155. if not self.has_collection(collection_name):
  156. return None
  157. query_body = {
  158. "query": {"bool": {"filter": []}},
  159. "_source": ["text", "metadata"],
  160. }
  161. for field, value in filter.items():
  162. query_body["query"]["bool"]["filter"].append({"term": {field: value}})
  163. query_body["query"]["bool"]["filter"].append(
  164. {"term": {"collection": collection_name}}
  165. )
  166. size = limit if limit else 10
  167. try:
  168. result = self.client.search(
  169. index=f"{self.index_prefix}*",
  170. body=query_body,
  171. size=size,
  172. )
  173. return self._result_to_get_result(result)
  174. except Exception as e:
  175. return None
  176. # Status: works
  177. def _has_index(self, dimension: int):
  178. return self.client.indices.exists(
  179. index=self._get_index_name(dimension=dimension)
  180. )
  181. def get_or_create_index(self, dimension: int):
  182. if not self._has_index(dimension=dimension):
  183. self._create_index(dimension=dimension)
  184. # Status: works
  185. def get(self, collection_name: str) -> Optional[GetResult]:
  186. # Get all the items in the collection.
  187. query = {
  188. "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
  189. "_source": ["text", "metadata"],
  190. }
  191. results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
  192. return self._scan_result_to_get_result(results)
  193. # Status: works
  194. def insert(self, collection_name: str, items: list[VectorItem]):
  195. if not self._has_index(dimension=len(items[0]["vector"])):
  196. self._create_index(dimension=len(items[0]["vector"]))
  197. for batch in self._create_batches(items):
  198. actions = [
  199. {
  200. "_index": self._get_index_name(dimension=len(items[0]["vector"])),
  201. "_id": item["id"],
  202. "_source": {
  203. "collection": collection_name,
  204. "vector": item["vector"],
  205. "text": item["text"],
  206. "metadata": item["metadata"],
  207. },
  208. }
  209. for item in batch
  210. ]
  211. bulk(self.client, actions)
  212. # Upsert documents using the update API with doc_as_upsert=True.
  213. def upsert(self, collection_name: str, items: list[VectorItem]):
  214. if not self._has_index(dimension=len(items[0]["vector"])):
  215. self._create_index(dimension=len(items[0]["vector"]))
  216. for batch in self._create_batches(items):
  217. actions = [
  218. {
  219. "_op_type": "update",
  220. "_index": self._get_index_name(dimension=len(item["vector"])),
  221. "_id": item["id"],
  222. "doc": {
  223. "collection": collection_name,
  224. "vector": item["vector"],
  225. "text": item["text"],
  226. "metadata": item["metadata"],
  227. },
  228. "doc_as_upsert": True,
  229. }
  230. for item in batch
  231. ]
  232. bulk(self.client, actions)
  233. # Delete specific documents from a collection by filtering on both collection and document IDs.
  234. def delete(
  235. self,
  236. collection_name: str,
  237. ids: Optional[list[str]] = None,
  238. filter: Optional[dict] = None,
  239. ):
  240. query = {
  241. "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
  242. }
  243. # logic based on chromaDB
  244. if ids:
  245. query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
  246. elif filter:
  247. for field, value in filter.items():
  248. query["query"]["bool"]["filter"].append(
  249. {"term": {f"metadata.{field}": value}}
  250. )
  251. self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
  252. def reset(self):
  253. indices = self.client.indices.get(index=f"{self.index_prefix}*")
  254. for index in indices:
  255. self.client.indices.delete(index=index)