# pgvector.py

from typing import Optional, List, Dict, Any

from sqlalchemy import (
    cast,
    column,
    create_engine,
    Column,
    Integer,
    MetaData,
    select,
    text,
    Text,
    values,
)
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH

VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()


class DocumentChunk(Base):
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)
    text = Column(Text, nullable=True)
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
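
# For reference, Base.metadata.create_all() in PgvectorClient.__init__ emits
# DDL roughly equivalent to the following for this model (an approximate
# sketch; exact output depends on the SQLAlchemy and pgvector versions in use):
#
#   CREATE TABLE document_chunk (
#       id TEXT PRIMARY KEY,
#       vector VECTOR(<VECTOR_LENGTH>),
#       collection_name TEXT NOT NULL,
#       text TEXT,
#       vmetadata JSONB
#   );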


class PgvectorClient:
    def __init__(self) -> None:
        # if no pgvector uri, use the existing database connection
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            engine = create_engine(
                PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
            )
            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available
            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

            # Check vector length consistency
            self.check_vector_length()

            # Create the tables if they do not exist
            # Base.metadata.create_all requires a bind (engine or connection)
            # Get the connection from the session
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # Create an index on the vector column if it doesn't exist
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )

            self.session.commit()
            print("Initialization complete.")
        except Exception as e:
            self.session.rollback()
            print(f"Error during initialization: {e}")
            raise
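
    # Tuning note (an aside, not part of this module's behavior): the ivfflat
    # index created above partitions vectors into 100 lists; at query time,
    # recall can be traded for speed by raising the number of probed lists per
    # session with plain SQL, e.g.:
    #
    #   SET ivfflat.probes = 10;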

    def check_vector_length(self) -> None:
        """
        Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
        Raises an exception if there is a mismatch.
        """
        metadata = MetaData()
        try:
            metadata.reflect(bind=self.session.bind, only=["document_chunk"])
        except InvalidRequestError:
            # reflect(only=[...]) raises when a requested table is absent, so a
            # fresh database (no 'document_chunk' table yet) lands here; there
            # is nothing to check in that case.
            return

        if "document_chunk" in metadata.tables:
            document_chunk_table = metadata.tables["document_chunk"]
            if "vector" in document_chunk_table.columns:
                vector_column = document_chunk_table.columns["vector"]
                vector_type = vector_column.type
                if isinstance(vector_type, Vector):
                    db_vector_length = vector_type.dim
                    if db_vector_length != VECTOR_LENGTH:
                        raise Exception(
                            f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
                            "Cannot change vector size after initialization without migrating the data."
                        )
                else:
                    raise Exception(
                        "The 'vector' column exists but is not of type 'Vector'."
                    )
            else:
                raise Exception(
                    "The 'vector' column does not exist in the 'document_chunk' table."
                )
        else:
            # Table does not exist yet; no action needed
            pass

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
        # Adjust vector to have length VECTOR_LENGTH
        current_length = len(vector)
        if current_length < VECTOR_LENGTH:
            # Pad the vector with zeros
            vector += [0.0] * (VECTOR_LENGTH - current_length)
        elif current_length > VECTOR_LENGTH:
            raise Exception(
                f"Vector length {current_length} exceeds the maximum supported length {VECTOR_LENGTH}."
            )
        return vector
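
    # Example of the padding behavior (hypothetical values, assuming
    # VECTOR_LENGTH == 4 and `client` is a PgvectorClient instance):
    #
    #   client.adjust_vector_length([0.1, 0.2])  # -> [0.1, 0.2, 0.0, 0.0]
    #
    # Note that padding happens in place via +=, so the caller's list object
    # is mutated as a side effect.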

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            new_items = []
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                new_chunk = DocumentChunk(
                    id=item["id"],
                    vector=vector,
                    collection_name=collection_name,
                    text=item["text"],
                    vmetadata=item["metadata"],
                )
                new_items.append(new_chunk)
            self.session.bulk_save_objects(new_items)
            self.session.commit()
            print(
                f"Inserted {len(new_items)} items into collection '{collection_name}'."
            )
        except Exception as e:
            self.session.rollback()
            print(f"Error during insert: {e}")
            raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                existing = (
                    self.session.query(DocumentChunk)
                    .filter(DocumentChunk.id == item["id"])
                    .first()
                )
                if existing:
                    existing.vector = vector
                    existing.text = item["text"]
                    existing.vmetadata = item["metadata"]
                    existing.collection_name = (
                        collection_name  # Update collection_name if necessary
                    )
                else:
                    new_chunk = DocumentChunk(
                        id=item["id"],
                        vector=vector,
                        collection_name=collection_name,
                        text=item["text"],
                        vmetadata=item["metadata"],
                    )
                    self.session.add(new_chunk)
            self.session.commit()
            print(f"Upserted {len(items)} items into collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            print(f"Error during upsert: {e}")
            raise
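
    # Design note: the upsert above issues one SELECT per item, which is
    # simple but costs O(n) round-trips. A set-based alternative (a sketch,
    # not what this module currently does) would be PostgreSQL's native
    # upsert:
    #
    #   INSERT INTO document_chunk (id, vector, collection_name, text, vmetadata)
    #   VALUES (...)
    #   ON CONFLICT (id) DO UPDATE
    #   SET vector = EXCLUDED.vector,
    #       text = EXCLUDED.text,
    #       vmetadata = EXCLUDED.vmetadata,
    #       collection_name = EXCLUDED.collection_name;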

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        try:
            if not vectors:
                return None

            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Create the values for query vectors
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            # Build the lateral subquery for each query vector
            subq = (
                select(
                    DocumentChunk.id,
                    DocumentChunk.text,
                    DocumentChunk.vmetadata,
                    (
                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
                    ).label("distance"),
                )
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            if not results:
                return SearchResult(
                    ids=ids,
                    distances=distances,
                    documents=documents,
                    metadatas=metadatas,
                )

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                distances[qid].append(row.distance)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            print(f"Error during search: {e}")
            return None
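
    # For reference, the statement built in search() renders to SQL of roughly
    # this shape (an approximate sketch; `<=>` is pgvector's cosine distance
    # operator, placeholders stand in for bound parameters, and the inner
    # LIMIT only appears when a limit is given):
    #
    #   SELECT query_vectors.qid, result.id, result.text, result.vmetadata,
    #          result.distance
    #   FROM (VALUES (0, :v0), (1, :v1)) AS query_vectors (qid, q_vector)
    #   JOIN LATERAL (
    #       SELECT id, text, vmetadata,
    #              vector <=> query_vectors.q_vector AS distance
    #       FROM document_chunk
    #       WHERE collection_name = :collection_name
    #       ORDER BY distance
    #       LIMIT :limit
    #   ) AS result ON TRUE
    #   ORDER BY query_vectors.qid, result.distance;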

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            for key, value in filter.items():
                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
            if limit is not None:
                query = query.limit(limit)
            results = query.all()

            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]

            return GetResult(
                ids=ids,
                documents=documents,
                metadatas=metadatas,
            )
        except Exception as e:
            print(f"Error during query: {e}")
            return None

    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if limit is not None:
                query = query.limit(limit)
            results = query.all()

            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]

            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
            print(f"Error during get: {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if ids:
                query = query.filter(DocumentChunk.id.in_(ids))
            if filter:
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
            deleted = query.delete(synchronize_session=False)
            self.session.commit()
            print(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            print(f"Error during delete: {e}")
            raise

    def reset(self) -> None:
        try:
            deleted = self.session.query(DocumentChunk).delete()
            self.session.commit()
            print(
                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
            )
        except Exception as e:
            self.session.rollback()
            print(f"Error during reset: {e}")
            raise

    def close(self) -> None:
        pass

    def has_collection(self, collection_name: str) -> bool:
        try:
            exists = (
                self.session.query(DocumentChunk)
                .filter(DocumentChunk.collection_name == collection_name)
                .first()
                is not None
            )
            return exists
        except Exception as e:
            print(f"Error checking collection existence: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
        self.delete(collection_name)
        print(f"Collection '{collection_name}' deleted.")
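

# Example usage (a minimal sketch; assumes a reachable Postgres instance with
# the pgvector extension available, and uses hypothetical ids and vectors
# purely for illustration):
#
#   client = PgvectorClient()
#   client.upsert(
#       "my_collection",
#       [
#           {
#               "id": "doc-1",
#               "vector": [0.1, 0.2, 0.3],  # padded to VECTOR_LENGTH on write
#               "text": "hello world",
#               "metadata": {"source": "demo"},
#           }
#       ],
#   )
#   result = client.search("my_collection", vectors=[[0.1, 0.2, 0.3]], limit=5)
#   client.close()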