# pgvector.py

  1. from typing import Optional, List, Dict, Any
  2. from sqlalchemy import (
  3. cast,
  4. column,
  5. create_engine,
  6. Column,
  7. Integer,
  8. MetaData,
  9. select,
  10. text,
  11. Text,
  12. Table,
  13. values,
  14. )
  15. from sqlalchemy.sql import true
  16. from sqlalchemy.pool import NullPool
  17. from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
  18. from sqlalchemy.dialects.postgresql import JSONB, array
  19. from pgvector.sqlalchemy import Vector
  20. from sqlalchemy.ext.mutable import MutableDict
  21. from sqlalchemy.exc import NoSuchTableError
  22. from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
  23. from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
# Fixed dimensionality of the stored vector column. Shorter embeddings are
# zero-padded up to this length; longer ones are rejected (see
# PgvectorClient.adjust_vector_length). Changing it after the table exists
# requires a data migration (see PgvectorClient.check_vector_length).
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH

# Declarative base for the ORM model(s) below.
Base = declarative_base()
class DocumentChunk(Base):
    """ORM model for one embedded document chunk in the ``document_chunk`` table."""

    __tablename__ = "document_chunk"

    # Caller-supplied chunk id (primary key).
    id = Column(Text, primary_key=True)
    # Embedding, zero-padded to VECTOR_LENGTH dimensions.
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    # Logical collection the chunk belongs to (indexed during init).
    collection_name = Column(Text, nullable=False)
    # Raw chunk text.
    text = Column(Text, nullable=True)
    # Arbitrary JSON metadata; MutableDict so in-place edits are tracked.
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
  33. class PgvectorClient:
  34. def __init__(self) -> None:
  35. # if no pgvector uri, use the existing database connection
  36. if not PGVECTOR_DB_URL:
  37. from open_webui.internal.db import Session
  38. self.session = Session
  39. else:
  40. engine = create_engine(
  41. PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
  42. )
  43. SessionLocal = sessionmaker(
  44. autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
  45. )
  46. self.session = scoped_session(SessionLocal)
  47. try:
  48. # Ensure the pgvector extension is available
  49. self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
  50. # Check vector length consistency
  51. self.check_vector_length()
  52. # Create the tables if they do not exist
  53. # Base.metadata.create_all requires a bind (engine or connection)
  54. # Get the connection from the session
  55. connection = self.session.connection()
  56. Base.metadata.create_all(bind=connection)
  57. # Create an index on the vector column if it doesn't exist
  58. self.session.execute(
  59. text(
  60. "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
  61. "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
  62. )
  63. )
  64. self.session.execute(
  65. text(
  66. "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
  67. "ON document_chunk (collection_name);"
  68. )
  69. )
  70. self.session.commit()
  71. print("Initialization complete.")
  72. except Exception as e:
  73. self.session.rollback()
  74. print(f"Error during initialization: {e}")
  75. raise
  76. def check_vector_length(self) -> None:
  77. """
  78. Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
  79. Raises an exception if there is a mismatch.
  80. """
  81. metadata = MetaData()
  82. try:
  83. # Attempt to reflect the 'document_chunk' table
  84. document_chunk_table = Table(
  85. "document_chunk", metadata, autoload_with=self.session.bind
  86. )
  87. except NoSuchTableError:
  88. # Table does not exist; no action needed
  89. return
  90. # Proceed to check the vector column
  91. if "vector" in document_chunk_table.columns:
  92. vector_column = document_chunk_table.columns["vector"]
  93. vector_type = vector_column.type
  94. if isinstance(vector_type, Vector):
  95. db_vector_length = vector_type.dim
  96. if db_vector_length != VECTOR_LENGTH:
  97. raise Exception(
  98. f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
  99. "Cannot change vector size after initialization without migrating the data."
  100. )
  101. else:
  102. raise Exception(
  103. "The 'vector' column exists but is not of type 'Vector'."
  104. )
  105. else:
  106. raise Exception(
  107. "The 'vector' column does not exist in the 'document_chunk' table."
  108. )
  109. def adjust_vector_length(self, vector: List[float]) -> List[float]:
  110. # Adjust vector to have length VECTOR_LENGTH
  111. current_length = len(vector)
  112. if current_length < VECTOR_LENGTH:
  113. # Pad the vector with zeros
  114. vector += [0.0] * (VECTOR_LENGTH - current_length)
  115. elif current_length > VECTOR_LENGTH:
  116. raise Exception(
  117. f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
  118. )
  119. return vector
  120. def insert(self, collection_name: str, items: List[VectorItem]) -> None:
  121. try:
  122. new_items = []
  123. for item in items:
  124. vector = self.adjust_vector_length(item["vector"])
  125. new_chunk = DocumentChunk(
  126. id=item["id"],
  127. vector=vector,
  128. collection_name=collection_name,
  129. text=item["text"],
  130. vmetadata=item["metadata"],
  131. )
  132. new_items.append(new_chunk)
  133. self.session.bulk_save_objects(new_items)
  134. self.session.commit()
  135. print(
  136. f"Inserted {len(new_items)} items into collection '{collection_name}'."
  137. )
  138. except Exception as e:
  139. self.session.rollback()
  140. print(f"Error during insert: {e}")
  141. raise
  142. def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
  143. try:
  144. for item in items:
  145. vector = self.adjust_vector_length(item["vector"])
  146. existing = (
  147. self.session.query(DocumentChunk)
  148. .filter(DocumentChunk.id == item["id"])
  149. .first()
  150. )
  151. if existing:
  152. existing.vector = vector
  153. existing.text = item["text"]
  154. existing.vmetadata = item["metadata"]
  155. existing.collection_name = (
  156. collection_name # Update collection_name if necessary
  157. )
  158. else:
  159. new_chunk = DocumentChunk(
  160. id=item["id"],
  161. vector=vector,
  162. collection_name=collection_name,
  163. text=item["text"],
  164. vmetadata=item["metadata"],
  165. )
  166. self.session.add(new_chunk)
  167. self.session.commit()
  168. print(f"Upserted {len(items)} items into collection '{collection_name}'.")
  169. except Exception as e:
  170. self.session.rollback()
  171. print(f"Error during upsert: {e}")
  172. raise
  173. def search(
  174. self,
  175. collection_name: str,
  176. vectors: List[List[float]],
  177. limit: Optional[int] = None,
  178. ) -> Optional[SearchResult]:
  179. try:
  180. if not vectors:
  181. return None
  182. # Adjust query vectors to VECTOR_LENGTH
  183. vectors = [self.adjust_vector_length(vector) for vector in vectors]
  184. num_queries = len(vectors)
  185. def vector_expr(vector):
  186. return cast(array(vector), Vector(VECTOR_LENGTH))
  187. # Create the values for query vectors
  188. qid_col = column("qid", Integer)
  189. q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
  190. query_vectors = (
  191. values(qid_col, q_vector_col)
  192. .data(
  193. [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
  194. )
  195. .alias("query_vectors")
  196. )
  197. # Build the lateral subquery for each query vector
  198. subq = (
  199. select(
  200. DocumentChunk.id,
  201. DocumentChunk.text,
  202. DocumentChunk.vmetadata,
  203. (
  204. DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
  205. ).label("distance"),
  206. )
  207. .where(DocumentChunk.collection_name == collection_name)
  208. .order_by(
  209. (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
  210. )
  211. )
  212. if limit is not None:
  213. subq = subq.limit(limit)
  214. subq = subq.lateral("result")
  215. # Build the main query by joining query_vectors and the lateral subquery
  216. stmt = (
  217. select(
  218. query_vectors.c.qid,
  219. subq.c.id,
  220. subq.c.text,
  221. subq.c.vmetadata,
  222. subq.c.distance,
  223. )
  224. .select_from(query_vectors)
  225. .join(subq, true())
  226. .order_by(query_vectors.c.qid, subq.c.distance)
  227. )
  228. result_proxy = self.session.execute(stmt)
  229. results = result_proxy.all()
  230. ids = [[] for _ in range(num_queries)]
  231. distances = [[] for _ in range(num_queries)]
  232. documents = [[] for _ in range(num_queries)]
  233. metadatas = [[] for _ in range(num_queries)]
  234. if not results:
  235. return SearchResult(
  236. ids=ids,
  237. distances=distances,
  238. documents=documents,
  239. metadatas=metadatas,
  240. )
  241. for row in results:
  242. qid = int(row.qid)
  243. ids[qid].append(row.id)
  244. distances[qid].append(row.distance)
  245. documents[qid].append(row.text)
  246. metadatas[qid].append(row.vmetadata)
  247. return SearchResult(
  248. ids=ids, distances=distances, documents=documents, metadatas=metadatas
  249. )
  250. except Exception as e:
  251. print(f"Error during search: {e}")
  252. return None
  253. def query(
  254. self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
  255. ) -> Optional[GetResult]:
  256. try:
  257. query = self.session.query(DocumentChunk).filter(
  258. DocumentChunk.collection_name == collection_name
  259. )
  260. for key, value in filter.items():
  261. query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
  262. if limit is not None:
  263. query = query.limit(limit)
  264. results = query.all()
  265. if not results:
  266. return None
  267. ids = [[result.id for result in results]]
  268. documents = [[result.text for result in results]]
  269. metadatas = [[result.vmetadata for result in results]]
  270. return GetResult(
  271. ids=ids,
  272. documents=documents,
  273. metadatas=metadatas,
  274. )
  275. except Exception as e:
  276. print(f"Error during query: {e}")
  277. return None
  278. def get(
  279. self, collection_name: str, limit: Optional[int] = None
  280. ) -> Optional[GetResult]:
  281. try:
  282. query = self.session.query(DocumentChunk).filter(
  283. DocumentChunk.collection_name == collection_name
  284. )
  285. if limit is not None:
  286. query = query.limit(limit)
  287. results = query.all()
  288. if not results:
  289. return None
  290. ids = [[result.id for result in results]]
  291. documents = [[result.text for result in results]]
  292. metadatas = [[result.vmetadata for result in results]]
  293. return GetResult(ids=ids, documents=documents, metadatas=metadatas)
  294. except Exception as e:
  295. print(f"Error during get: {e}")
  296. return None
  297. def delete(
  298. self,
  299. collection_name: str,
  300. ids: Optional[List[str]] = None,
  301. filter: Optional[Dict[str, Any]] = None,
  302. ) -> None:
  303. try:
  304. query = self.session.query(DocumentChunk).filter(
  305. DocumentChunk.collection_name == collection_name
  306. )
  307. if ids:
  308. query = query.filter(DocumentChunk.id.in_(ids))
  309. if filter:
  310. for key, value in filter.items():
  311. query = query.filter(
  312. DocumentChunk.vmetadata[key].astext == str(value)
  313. )
  314. deleted = query.delete(synchronize_session=False)
  315. self.session.commit()
  316. print(f"Deleted {deleted} items from collection '{collection_name}'.")
  317. except Exception as e:
  318. self.session.rollback()
  319. print(f"Error during delete: {e}")
  320. raise
  321. def reset(self) -> None:
  322. try:
  323. deleted = self.session.query(DocumentChunk).delete()
  324. self.session.commit()
  325. print(
  326. f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
  327. )
  328. except Exception as e:
  329. self.session.rollback()
  330. print(f"Error during reset: {e}")
  331. raise
  332. def close(self) -> None:
  333. pass
  334. def has_collection(self, collection_name: str) -> bool:
  335. try:
  336. exists = (
  337. self.session.query(DocumentChunk)
  338. .filter(DocumentChunk.collection_name == collection_name)
  339. .first()
  340. is not None
  341. )
  342. return exists
  343. except Exception as e:
  344. print(f"Error checking collection existence: {e}")
  345. return False
  346. def delete_collection(self, collection_name: str) -> None:
  347. self.delete(collection_name)
  348. print(f"Collection '{collection_name}' deleted.")