elasticsearch.py

from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    ELASTICSEARCH_URL,
    ELASTICSEARCH_CA_CERTS,
    ELASTICSEARCH_API_KEY,
    ELASTICSEARCH_USERNAME,
    ELASTICSEARCH_PASSWORD,
    ELASTICSEARCH_CLOUD_ID,
    SSL_ASSERT_FINGERPRINT,
)


class ElasticsearchClient:
    """
    Important:
    To reduce the number of indexes, and since the embedding vector length is
    fixed, we avoid creating an index per file. Instead, the text is stored as
    a regular field and documents are split into separate indexes based on the
    embedding length.
    """

    def __init__(self):
        self.index_prefix = "open_webui_collections"
        self.client = Elasticsearch(
            hosts=[ELASTICSEARCH_URL],
            ca_certs=ELASTICSEARCH_CA_CERTS,
            api_key=ELASTICSEARCH_API_KEY,
            cloud_id=ELASTICSEARCH_CLOUD_ID,
            basic_auth=(
                (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
                if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
                else None
            ),
            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
        )

    # Status: works
    def _get_index_name(self, dimension: int) -> str:
        return f"{self.index_prefix}_d{dimension}"

    # Status: works
    def _scan_result_to_get_result(self, result) -> GetResult:
        if not result:
            return None
        ids = []
        documents = []
        metadatas = []
        for hit in result:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))
        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_get_result(self, result) -> GetResult:
        if not result["hits"]["hits"]:
            return None
        ids = []
        documents = []
        metadatas = []
        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))
        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []
        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            distances.append(hit["_score"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))
        return SearchResult(
            ids=[ids],
            distances=[distances],
            documents=[documents],
            metadatas=[metadatas],
        )

    # Status: works
    def _create_index(self, dimension: int):
        body = {
            "mappings": {
                "properties": {
                    "collection": {"type": "keyword"},
                    "id": {"type": "keyword"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": dimension,  # Must match the embedding length
                        "index": True,
                        "similarity": "cosine",
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                }
            }
        }
        self.client.indices.create(index=self._get_index_name(dimension), body=body)
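
    # Illustrative shape of a stored document under the mapping above
    # (field values are hypothetical):
    # {
    #     "collection": "my_collection",
    #     "id": "doc-1",
    #     "vector": [0.12, -0.03, ...],  # length equals the index dimension
    #     "text": "chunk contents",
    #     "metadata": {"source": "file.pdf"},
    # }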

    # Status: works
    def _create_batches(self, items: list[VectorItem], batch_size=100):
        # Slicing past the end of a list is safe in Python, so the final,
        # possibly shorter batch needs no explicit bounds check.
        for i in range(0, len(items), batch_size):
            yield items[i : i + batch_size]

    # Status: works
    def has_collection(self, collection_name) -> bool:
        query_body = {"query": {"bool": {"filter": []}}}
        query_body["query"]["bool"]["filter"].append(
            {"term": {"collection": collection_name}}
        )
        try:
            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
            return result.body["count"] > 0
        except Exception:
            # The method is typed as bool, so treat failures as "not found"
            # rather than returning None.
            return False

    def delete_collection(self, collection_name: str):
        # "Deleting a collection" here means deleting that collection's
        # documents: several collections share a per-dimension index, so we
        # remove matching documents across all indexes rather than dropping
        # an index. We are simply adapting to the norms of the other DBs.
        self.client.delete_by_query(
            index=f"{self.index_prefix}*",
            body={"query": {"term": {"collection": collection_name}}},
        )

    # Status: works
    def search(
        self, collection_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {
                        "bool": {"filter": [{"term": {"collection": collection_name}}]}
                    },
                    "script": {
                        # cosineSimilarity returns [-1, 1]; the +1.0 shifts
                        # scores into [0, 2] since ES scores must be
                        # non-negative.
                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                        "params": {
                            "vector": vectors[0]
                        },  # Assuming a single query vector
                    },
                }
            },
        }
        result = self.client.search(
            index=self._get_index_name(len(vectors[0])), body=query
        )
        return self._result_to_search_result(result)

    # Status: only tested halfway
    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        if not self.has_collection(collection_name):
            return None
        query_body = {
            "query": {"bool": {"filter": []}},
            "_source": ["text", "metadata"],
        }
        for field, value in filter.items():
            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
        query_body["query"]["bool"]["filter"].append(
            {"term": {"collection": collection_name}}
        )
        size = limit if limit else 10
        try:
            result = self.client.search(
                index=f"{self.index_prefix}*",
                body=query_body,
                size=size,
            )
            return self._result_to_get_result(result)
        except Exception:
            return None

    # Status: works
    def _has_index(self, dimension: int):
        return self.client.indices.exists(
            index=self._get_index_name(dimension=dimension)
        )

    def get_or_create_index(self, dimension: int):
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)

    # Status: works
    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection.
        query = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
            "_source": ["text", "metadata"],
        }
        results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
        return self._scan_result_to_get_result(results)

    # Status: works
    def insert(self, collection_name: str, items: list[VectorItem]):
        dimension = len(items[0]["vector"])
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)
        for batch in self._create_batches(items):
            actions = [
                {
                    "_index": self._get_index_name(dimension=dimension),
                    "_id": item["id"],
                    "_source": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                }
                for item in batch
            ]
            bulk(self.client, actions)

    # Status: should work
    def upsert(self, collection_name: str, items: list[VectorItem]):
        dimension = len(items[0]["vector"])
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)
        for batch in self._create_batches(items):
            # The default "index" op type overwrites an existing document
            # with the same _id, which gives us upsert semantics.
            actions = [
                {
                    "_index": self._get_index_name(dimension=dimension),
                    "_id": item["id"],
                    "_source": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                }
                for item in batch
            ]
            bulk(self.client, actions)

    def delete(self, collection_name: str, ids: list[str]):
        # Bulk delete actions cannot target a wildcard index name, so use
        # delete_by_query across the dimension indexes, scoped to both the
        # collection and the given document ids.
        self.client.delete_by_query(
            index=f"{self.index_prefix}*",
            body={
                "query": {
                    "bool": {
                        "filter": [
                            {"term": {"collection": collection_name}},
                            {"ids": {"values": ids}},
                        ]
                    }
                }
            },
        )

    def reset(self):
        indices = self.client.indices.get(index=f"{self.index_prefix}*")
        for index in indices:
            self.client.indices.delete(index=index)
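

# Usage sketch (hypothetical values; assumes the ELASTICSEARCH_* settings in
# open_webui.config point at a reachable cluster):
#
#     client = ElasticsearchClient()
#     items = [
#         {"id": "doc-1", "vector": embedding, "text": "chunk", "metadata": {}}
#     ]  # VectorItem entries; "embedding" is a fixed-length list[float]
#     client.insert("my_collection", items)
#     hits = client.search("my_collection", vectors=[query_embedding], limit=5)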