# pgvector.py — Postgres + pgvector backed vector-database client for Open WebUI retrieval.
  1. from typing import Optional, List, Dict, Any
  2. from sqlalchemy import (
  3. cast,
  4. column,
  5. Column,
  6. Integer,
  7. select,
  8. text,
  9. Text,
  10. values,
  11. )
  12. from sqlalchemy.sql import true
  13. from sqlalchemy.orm import declarative_base, Session
  14. from sqlalchemy.dialects.postgresql import JSONB, array
  15. from pgvector.sqlalchemy import Vector
  16. from sqlalchemy.ext.mutable import MutableDict
  17. from sqlalchemy.ext.declarative import declarative_base
  18. from open_webui.apps.webui.internal.db import Session
  19. from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
# Fixed dimensionality of the `vector` column. Shorter embeddings are
# zero-padded up to this length and longer ones are rejected
# (see PgvectorClient.adjust_vector_length).
VECTOR_LENGTH = 1536

# Declarative base shared by the ORM models defined in this module.
Base = declarative_base()
class DocumentChunk(Base):
    """ORM model for one embedded document chunk.

    All collections share this single table; rows are scoped by
    ``collection_name`` rather than by separate per-collection tables.
    """

    __tablename__ = "document_chunk"

    # Caller-supplied unique chunk id.
    id = Column(Text, primary_key=True)
    # Embedding, expected to be exactly VECTOR_LENGTH dimensions.
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    # Logical collection the chunk belongs to.
    collection_name = Column(Text, nullable=False)
    # Raw chunk text.
    text = Column(Text, nullable=True)
    # Arbitrary JSON metadata; named "vmetadata" because "metadata" is
    # reserved on SQLAlchemy declarative classes. MutableDict lets in-place
    # dict changes mark the row dirty.
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
  29. class PgvectorClient:
  30. def __init__(self) -> None:
  31. self.session = Session
  32. try:
  33. # Ensure the pgvector extension is available
  34. self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
  35. # Create the tables if they do not exist
  36. # Base.metadata.create_all requires a bind (engine or connection)
  37. # Get the connection from the session
  38. connection = self.session.connection()
  39. Base.metadata.create_all(bind=connection)
  40. # Create an index on the vector column if it doesn't exist
  41. self.session.execute(
  42. text(
  43. "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
  44. "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
  45. )
  46. )
  47. self.session.execute(
  48. text(
  49. "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
  50. "ON document_chunk (collection_name);"
  51. )
  52. )
  53. self.session.commit()
  54. print("Initialization complete.")
  55. except Exception as e:
  56. self.session.rollback()
  57. print(f"Error during initialization: {e}")
  58. raise
  59. def adjust_vector_length(self, vector: List[float]) -> List[float]:
  60. # Adjust vector to have length VECTOR_LENGTH
  61. current_length = len(vector)
  62. if current_length < VECTOR_LENGTH:
  63. # Pad the vector with zeros
  64. vector += [0.0] * (VECTOR_LENGTH - current_length)
  65. elif current_length > VECTOR_LENGTH:
  66. raise Exception(
  67. f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
  68. )
  69. return vector
  70. def insert(self, collection_name: str, items: List[VectorItem]) -> None:
  71. try:
  72. new_items = []
  73. for item in items:
  74. vector = self.adjust_vector_length(item["vector"])
  75. new_chunk = DocumentChunk(
  76. id=item["id"],
  77. vector=vector,
  78. collection_name=collection_name,
  79. text=item["text"],
  80. vmetadata=item["metadata"],
  81. )
  82. new_items.append(new_chunk)
  83. self.session.bulk_save_objects(new_items)
  84. self.session.commit()
  85. print(
  86. f"Inserted {len(new_items)} items into collection '{collection_name}'."
  87. )
  88. except Exception as e:
  89. self.session.rollback()
  90. print(f"Error during insert: {e}")
  91. raise
  92. def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
  93. try:
  94. for item in items:
  95. vector = self.adjust_vector_length(item["vector"])
  96. existing = (
  97. self.session.query(DocumentChunk)
  98. .filter(DocumentChunk.id == item["id"])
  99. .first()
  100. )
  101. if existing:
  102. existing.vector = vector
  103. existing.text = item["text"]
  104. existing.vmetadata = item["metadata"]
  105. existing.collection_name = (
  106. collection_name # Update collection_name if necessary
  107. )
  108. else:
  109. new_chunk = DocumentChunk(
  110. id=item["id"],
  111. vector=vector,
  112. collection_name=collection_name,
  113. text=item["text"],
  114. vmetadata=item["metadata"],
  115. )
  116. self.session.add(new_chunk)
  117. self.session.commit()
  118. print(f"Upserted {len(items)} items into collection '{collection_name}'.")
  119. except Exception as e:
  120. self.session.rollback()
  121. print(f"Error during upsert: {e}")
  122. raise
  123. def search(
  124. self,
  125. collection_name: str,
  126. vectors: List[List[float]],
  127. limit: Optional[int] = None,
  128. ) -> Optional[SearchResult]:
  129. try:
  130. if not vectors:
  131. return None
  132. # Adjust query vectors to VECTOR_LENGTH
  133. vectors = [self.adjust_vector_length(vector) for vector in vectors]
  134. num_queries = len(vectors)
  135. def vector_expr(vector):
  136. return cast(array(vector), Vector(VECTOR_LENGTH))
  137. # Create the values for query vectors
  138. qid_col = column("qid", Integer)
  139. q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
  140. query_vectors = (
  141. values(qid_col, q_vector_col)
  142. .data(
  143. [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
  144. )
  145. .alias("query_vectors")
  146. )
  147. # Build the lateral subquery for each query vector
  148. subq = (
  149. select(
  150. DocumentChunk.id,
  151. DocumentChunk.text,
  152. DocumentChunk.vmetadata,
  153. (
  154. DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
  155. ).label("distance"),
  156. )
  157. .where(DocumentChunk.collection_name == collection_name)
  158. .order_by(
  159. (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
  160. )
  161. )
  162. if limit is not None:
  163. subq = subq.limit(limit)
  164. subq = subq.lateral("result")
  165. # Build the main query by joining query_vectors and the lateral subquery
  166. stmt = (
  167. select(
  168. query_vectors.c.qid,
  169. subq.c.id,
  170. subq.c.text,
  171. subq.c.vmetadata,
  172. subq.c.distance,
  173. )
  174. .select_from(query_vectors)
  175. .join(subq, true())
  176. .order_by(query_vectors.c.qid, subq.c.distance)
  177. )
  178. result_proxy = self.session.execute(stmt)
  179. results = result_proxy.all()
  180. ids = [[] for _ in range(num_queries)]
  181. distances = [[] for _ in range(num_queries)]
  182. documents = [[] for _ in range(num_queries)]
  183. metadatas = [[] for _ in range(num_queries)]
  184. if not results:
  185. return SearchResult(
  186. ids=ids,
  187. distances=distances,
  188. documents=documents,
  189. metadatas=metadatas,
  190. )
  191. for row in results:
  192. qid = int(row.qid)
  193. ids[qid].append(row.id)
  194. distances[qid].append(row.distance)
  195. documents[qid].append(row.text)
  196. metadatas[qid].append(row.vmetadata)
  197. return SearchResult(
  198. ids=ids, distances=distances, documents=documents, metadatas=metadatas
  199. )
  200. except Exception as e:
  201. print(f"Error during search: {e}")
  202. return None
  203. def query(
  204. self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
  205. ) -> Optional[GetResult]:
  206. try:
  207. query = self.session.query(DocumentChunk).filter(
  208. DocumentChunk.collection_name == collection_name
  209. )
  210. for key, value in filter.items():
  211. query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
  212. if limit is not None:
  213. query = query.limit(limit)
  214. results = query.all()
  215. if not results:
  216. return None
  217. ids = [[result.id for result in results]]
  218. documents = [[result.text for result in results]]
  219. metadatas = [[result.vmetadata for result in results]]
  220. return GetResult(
  221. ids=ids,
  222. documents=documents,
  223. metadatas=metadatas,
  224. )
  225. except Exception as e:
  226. print(f"Error during query: {e}")
  227. return None
  228. def get(
  229. self, collection_name: str, limit: Optional[int] = None
  230. ) -> Optional[GetResult]:
  231. try:
  232. query = self.session.query(DocumentChunk).filter(
  233. DocumentChunk.collection_name == collection_name
  234. )
  235. if limit is not None:
  236. query = query.limit(limit)
  237. results = query.all()
  238. if not results:
  239. return None
  240. ids = [[result.id for result in results]]
  241. documents = [[result.text for result in results]]
  242. metadatas = [[result.vmetadata for result in results]]
  243. return GetResult(ids=ids, documents=documents, metadatas=metadatas)
  244. except Exception as e:
  245. print(f"Error during get: {e}")
  246. return None
  247. def delete(
  248. self,
  249. collection_name: str,
  250. ids: Optional[List[str]] = None,
  251. filter: Optional[Dict[str, Any]] = None,
  252. ) -> None:
  253. try:
  254. query = self.session.query(DocumentChunk).filter(
  255. DocumentChunk.collection_name == collection_name
  256. )
  257. if ids:
  258. query = query.filter(DocumentChunk.id.in_(ids))
  259. if filter:
  260. for key, value in filter.items():
  261. query = query.filter(
  262. DocumentChunk.vmetadata[key].astext == str(value)
  263. )
  264. deleted = query.delete(synchronize_session=False)
  265. self.session.commit()
  266. print(f"Deleted {deleted} items from collection '{collection_name}'.")
  267. except Exception as e:
  268. self.session.rollback()
  269. print(f"Error during delete: {e}")
  270. raise
  271. def reset(self) -> None:
  272. try:
  273. deleted = self.session.query(DocumentChunk).delete()
  274. self.session.commit()
  275. print(
  276. f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
  277. )
  278. except Exception as e:
  279. self.session.rollback()
  280. print(f"Error during reset: {e}")
  281. raise
  282. def close(self) -> None:
  283. pass
  284. def has_collection(self, collection_name: str) -> bool:
  285. try:
  286. exists = (
  287. self.session.query(DocumentChunk)
  288. .filter(DocumentChunk.collection_name == collection_name)
  289. .first()
  290. is not None
  291. )
  292. return exists
  293. except Exception as e:
  294. print(f"Error checking collection existence: {e}")
  295. return False
  296. def delete_collection(self, collection_name: str) -> None:
  297. self.delete(collection_name)
  298. print(f"Collection '{collection_name}' deleted.")