pgvector.py

from typing import Optional, List, Dict, Any
import logging

from sqlalchemy import (
    cast,
    column,
    create_engine,
    Column,
    Integer,
    MetaData,
    select,
    text,
    Text,
    Table,
    values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS

VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

class DocumentChunk(Base):
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)
    text = Column(Text, nullable=True)
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
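
# All collections share the single "document_chunk" table; rows are scoped by
# the collection_name column rather than given per-collection tables. Each
# VectorItem is expected to carry "id", "vector", "text", and "metadata" keys;
# "metadata" is stored under the vmetadata column, presumably because the
# attribute name "metadata" is reserved on SQLAlchemy declarative models.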

class PgvectorClient:
    def __init__(self) -> None:
        # If no pgvector URI is configured, reuse the app's existing
        # database connection.
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            engine = create_engine(
                PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
            )
            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available
            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

            # Check vector length consistency
            self.check_vector_length()

            # Create the tables if they do not exist
            # Base.metadata.create_all requires a bind (engine or connection),
            # so get the connection from the session
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # Create an index on the vector column if it doesn't exist
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )

            self.session.commit()
            log.info("Initialization complete.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during initialization: {e}")
            raise
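
    # Note: ivfflat is an approximate index. "lists = 100" partitions the
    # stored vectors into 100 clusters, and vector_cosine_ops matches the
    # cosine_distance() queries issued by search() below. pgvector recommends
    # building ivfflat indexes after the table already holds data, since
    # cluster centroids are chosen at index build time; an index created on an
    # empty table can give lower recall until it is rebuilt.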

    def check_vector_length(self) -> None:
        """
        Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
        Raises an exception if there is a mismatch.
        """
        metadata = MetaData()
        try:
            # Attempt to reflect the 'document_chunk' table
            document_chunk_table = Table(
                "document_chunk", metadata, autoload_with=self.session.bind
            )
        except NoSuchTableError:
            # Table does not exist; no action needed
            return

        # Proceed to check the vector column
        if "vector" in document_chunk_table.columns:
            vector_column = document_chunk_table.columns["vector"]
            vector_type = vector_column.type
            if isinstance(vector_type, Vector):
                db_vector_length = vector_type.dim
                if db_vector_length != VECTOR_LENGTH:
                    raise Exception(
                        f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
                        "Cannot change vector size after initialization without migrating the data."
                    )
            else:
                raise Exception(
                    "The 'vector' column exists but is not of type 'Vector'."
                )
        else:
            raise Exception(
                "The 'vector' column does not exist in the 'document_chunk' table."
            )

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
        # Adjust vector to have length VECTOR_LENGTH
        current_length = len(vector)
        if current_length < VECTOR_LENGTH:
            # Pad the vector with zeros
            vector += [0.0] * (VECTOR_LENGTH - current_length)
        elif current_length > VECTOR_LENGTH:
            raise Exception(
                f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
            )
        return vector
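
    # For example, with VECTOR_LENGTH = 4 (an illustrative value only):
    #   adjust_vector_length([0.1, 0.2])           -> [0.1, 0.2, 0.0, 0.0]
    #   adjust_vector_length([0.1, 0.2, 0.3, 0.4]) -> unchanged
    #   adjust_vector_length([0.1] * 5)            -> raises Exception
    # Zero-padding preserves cosine distance between vectors that share the
    # same original dimension, which is what lets shorter embeddings coexist
    # in a single fixed-width column.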

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            new_items = []
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                new_chunk = DocumentChunk(
                    id=item["id"],
                    vector=vector,
                    collection_name=collection_name,
                    text=item["text"],
                    vmetadata=item["metadata"],
                )
                new_items.append(new_chunk)
            self.session.bulk_save_objects(new_items)
            self.session.commit()
            log.info(
                f"Inserted {len(new_items)} items into collection '{collection_name}'."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during insert: {e}")
            raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                existing = (
                    self.session.query(DocumentChunk)
                    .filter(DocumentChunk.id == item["id"])
                    .first()
                )
                if existing:
                    existing.vector = vector
                    existing.text = item["text"]
                    existing.vmetadata = item["metadata"]
                    existing.collection_name = (
                        collection_name  # Update collection_name if necessary
                    )
                else:
                    new_chunk = DocumentChunk(
                        id=item["id"],
                        vector=vector,
                        collection_name=collection_name,
                        text=item["text"],
                        vmetadata=item["metadata"],
                    )
                    self.session.add(new_chunk)
            self.session.commit()
            log.info(
                f"Upserted {len(items)} items into collection '{collection_name}'."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during upsert: {e}")
            raise
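
    # Design note: upsert() issues a SELECT per item and then an UPDATE or
    # INSERT, all inside one transaction. A Postgres-native
    # INSERT ... ON CONFLICT (id) DO UPDATE would avoid the per-item round
    # trip for large batches; the ORM approach here trades that throughput
    # for simplicity.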

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        try:
            if not vectors:
                return None

            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Create the values for query vectors
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            # Build the lateral subquery for each query vector
            subq = (
                select(
                    DocumentChunk.id,
                    DocumentChunk.text,
                    DocumentChunk.vmetadata,
                    (
                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
                    ).label("distance"),
                )
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            if not results:
                return SearchResult(
                    ids=ids,
                    distances=distances,
                    documents=documents,
                    metadatas=metadatas,
                )

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                distances[qid].append(row.distance)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            log.exception(f"Error during search: {e}")
            return None
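
    # Roughly the SQL shape search() builds (illustrative, not the exact
    # emitted statement). One LATERAL scan runs per query vector, so a whole
    # batch of queries completes in a single round trip:
    #
    #   SELECT q.qid, r.id, r.text, r.vmetadata, r.distance
    #   FROM (VALUES (0, '[...]'::vector), (1, '[...]'::vector))
    #        AS q (qid, q_vector)
    #   JOIN LATERAL (
    #       SELECT id, text, vmetadata,
    #              vector <=> q.q_vector AS distance
    #       FROM document_chunk
    #       WHERE collection_name = :collection_name
    #       ORDER BY vector <=> q.q_vector
    #       LIMIT :limit
    #   ) AS r ON TRUE
    #   ORDER BY q.qid, r.distance;
    #
    # (<=> is pgvector's cosine-distance operator, which cosine_distance()
    # compiles to.)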

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            for key, value in filter.items():
                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
            if limit is not None:
                query = query.limit(limit)

            results = query.all()

            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]

            return GetResult(
                ids=ids,
                documents=documents,
                metadatas=metadatas,
            )
        except Exception as e:
            log.exception(f"Error during query: {e}")
            return None

    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if limit is not None:
                query = query.limit(limit)

            results = query.all()

            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]

            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
            log.exception(f"Error during get: {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if ids:
                query = query.filter(DocumentChunk.id.in_(ids))
            if filter:
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
            deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during delete: {e}")
            raise

    def reset(self) -> None:
        try:
            deleted = self.session.query(DocumentChunk).delete()
            self.session.commit()
            log.info(
                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during reset: {e}")
            raise

    def close(self) -> None:
        pass

    def has_collection(self, collection_name: str) -> bool:
        try:
            exists = (
                self.session.query(DocumentChunk)
                .filter(DocumentChunk.collection_name == collection_name)
                .first()
                is not None
            )
            return exists
        except Exception as e:
            log.exception(f"Error checking collection existence: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
        self.delete(collection_name)
        log.info(f"Collection '{collection_name}' deleted.")
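
# Example usage (a minimal sketch; assumes PGVECTOR_DB_URL points at a Postgres
# instance where the vector extension can be created, and that embeddings are
# at most VECTOR_LENGTH long -- collection and document names below are
# illustrative):
#
#   client = PgvectorClient()
#   client.upsert(
#       "my_collection",
#       [
#           {
#               "id": "doc-1",
#               "vector": [0.1, 0.2, 0.3],  # zero-padded to VECTOR_LENGTH
#               "text": "hello world",
#               "metadata": {"source": "example.txt"},
#           }
#       ],
#   )
#   result = client.search("my_collection", vectors=[[0.1, 0.2, 0.3]], limit=5)
#   client.delete_collection("my_collection")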