opensearch.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
import logging
from typing import Optional

from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk, scan

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    OPENSEARCH_URI,
    OPENSEARCH_SSL,
    OPENSEARCH_CERT_VERIFY,
    OPENSEARCH_USERNAME,
    OPENSEARCH_PASSWORD,
)
  12. class OpenSearchClient:
  13. def __init__(self):
  14. self.index_prefix = "open_webui"
  15. self.client = OpenSearch(
  16. hosts=[OPENSEARCH_URI],
  17. use_ssl=OPENSEARCH_SSL,
  18. verify_certs=OPENSEARCH_CERT_VERIFY,
  19. http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
  20. )
  21. def _get_index_name(self, collection_name: str) -> str:
  22. return f"{self.index_prefix}_{collection_name}"
  23. def _result_to_get_result(self, result) -> GetResult:
  24. if not result["hits"]["hits"]:
  25. return None
  26. ids = []
  27. documents = []
  28. metadatas = []
  29. for hit in result["hits"]["hits"]:
  30. ids.append(hit["_id"])
  31. documents.append(hit["_source"].get("text"))
  32. metadatas.append(hit["_source"].get("metadata"))
  33. return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
  34. def _result_to_search_result(self, result) -> SearchResult:
  35. if not result["hits"]["hits"]:
  36. return None
  37. ids = []
  38. distances = []
  39. documents = []
  40. metadatas = []
  41. for hit in result["hits"]["hits"]:
  42. ids.append(hit["_id"])
  43. distances.append(hit["_score"])
  44. documents.append(hit["_source"].get("text"))
  45. metadatas.append(hit["_source"].get("metadata"))
  46. return SearchResult(
  47. ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
  48. )
  49. def _create_index(self, collection_name: str, dimension: int):
  50. body = {
  51. "settings": {
  52. "index": {
  53. "knn": True
  54. }
  55. },
  56. "mappings": {
  57. "properties": {
  58. "id": {"type": "keyword"},
  59. "vector": {
  60. "type": "knn_vector",
  61. "dimension": dimension, # Adjust based on your vector dimensions
  62. "index": True,
  63. "similarity": "faiss",
  64. "method": {
  65. "name": "hnsw",
  66. "space_type": "innerproduct", # Use inner product to approximate cosine similarity
  67. "engine": "faiss",
  68. "parameters": {
  69. "ef_construction": 128,
  70. "m": 16,
  71. }
  72. },
  73. },
  74. "text": {"type": "text"},
  75. "metadata": {"type": "object"},
  76. }
  77. }
  78. }
  79. self.client.indices.create(
  80. index=self._get_index_name(collection_name), body=body
  81. )
  82. def _create_batches(self, items: list[VectorItem], batch_size=100):
  83. for i in range(0, len(items), batch_size):
  84. yield items[i : i + batch_size]
  85. def has_collection(self, collection_name: str) -> bool:
  86. # has_collection here means has index.
  87. # We are simply adapting to the norms of the other DBs.
  88. return self.client.indices.exists(
  89. index=self._get_index_name(collection_name)
  90. )
  91. def delete_collection(self, collection_name: str):
  92. # delete_collection here means delete index.
  93. # We are simply adapting to the norms of the other DBs.
  94. self.client.indices.delete(index=self._get_index_name(collection_name))
  95. def search(
  96. self, collection_name: str, vectors: list[list[float | int]], limit: int
  97. ) -> Optional[SearchResult]:
  98. try:
  99. if not self.has_collection(collection_name):
  100. return None
  101. query = {
  102. "size": limit,
  103. "_source": ["text", "metadata"],
  104. "query": {
  105. "script_score": {
  106. "query": {
  107. "match_all": {}
  108. },
  109. "script": {
  110. "source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
  111. "params": {
  112. "field": "vector",
  113. "query_value": vectors[0]
  114. }, # Assuming single query vector
  115. },
  116. }
  117. },
  118. }
  119. result = self.client.search(
  120. index=self._get_index_name(collection_name),
  121. body=query
  122. )
  123. return self._result_to_search_result(result)
  124. except Exception as e:
  125. return None
  126. def query(
  127. self, collection_name: str, filter: dict, limit: Optional[int] = None
  128. ) -> Optional[GetResult]:
  129. if not self.has_collection(collection_name):
  130. return None
  131. query_body = {
  132. "query": {
  133. "bool": {
  134. "filter": []
  135. }
  136. },
  137. "_source": ["text", "metadata"],
  138. }
  139. for field, value in filter.items():
  140. query_body["query"]["bool"]["filter"].append({
  141. "match": {
  142. "metadata." + str(field): value
  143. }
  144. })
  145. size = limit if limit else 10
  146. try:
  147. result = self.client.search(
  148. index=self._get_index_name(collection_name),
  149. body=query_body,
  150. size=size,
  151. )
  152. return self._result_to_get_result(result)
  153. except Exception as e:
  154. return None
  155. def _create_index_if_not_exists(self, collection_name: str, dimension: int):
  156. if not self.has_collection(collection_name):
  157. self._create_index(collection_name, dimension)
  158. def get(self, collection_name: str) -> Optional[GetResult]:
  159. query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
  160. result = self.client.search(
  161. index=self._get_index_name(collection_name), body=query
  162. )
  163. return self._result_to_get_result(result)
  164. def insert(self, collection_name: str, items: list[VectorItem]):
  165. self._create_index_if_not_exists(
  166. collection_name=collection_name, dimension=len(items[0]["vector"])
  167. )
  168. for batch in self._create_batches(items):
  169. actions = [
  170. {
  171. "_op_type": "index",
  172. "_index": self._get_index_name(collection_name),
  173. "_id": item["id"],
  174. "_source": {
  175. "vector": item["vector"],
  176. "text": item["text"],
  177. "metadata": item["metadata"],
  178. },
  179. }
  180. for item in batch
  181. ]
  182. bulk(self.client, actions)
  183. def upsert(self, collection_name: str, items: list[VectorItem]):
  184. self._create_index_if_not_exists(
  185. collection_name=collection_name, dimension=len(items[0]["vector"])
  186. )
  187. for batch in self._create_batches(items):
  188. actions = [
  189. {
  190. "_op_type": "update",
  191. "_index": self._get_index_name(collection_name),
  192. "_id": item["id"],
  193. "doc": {
  194. "vector": item["vector"],
  195. "text": item["text"],
  196. "metadata": item["metadata"],
  197. },
  198. "doc_as_upsert": True,
  199. }
  200. for item in batch
  201. ]
  202. bulk(self.client, actions)
  203. def delete(self, collection_name: str, ids: Optional[list[str]] = None, filter: Optional[dict] = None):
  204. if ids:
  205. actions = [
  206. {
  207. "_op_type": "delete",
  208. "_index": self._get_index_name(collection_name),
  209. "_id": id,
  210. }
  211. for id in ids
  212. ]
  213. bulk(self.client, actions)
  214. elif filter:
  215. query_body = {
  216. "query": {
  217. "bool": {
  218. "filter": []
  219. }
  220. },
  221. }
  222. for field, value in filter.items():
  223. query_body["query"]["bool"]["filter"].append({
  224. "match": {
  225. "metadata." + str(field): value
  226. }
  227. })
  228. self.client.delete_by_query(index=self._get_index_name(collection_name), body=query_body)
  229. def reset(self):
  230. indices = self.client.indices.get(index=f"{self.index_prefix}_*")
  231. for index in indices:
  232. self.client.indices.delete(index=index)