# pgvector.py — PostgreSQL/pgvector-backed vector database client for Open WebUI retrieval.
  1. from typing import Optional, List, Dict, Any
  2. from sqlalchemy import (
  3. cast,
  4. column,
  5. create_engine,
  6. Column,
  7. Integer,
  8. select,
  9. text,
  10. Text,
  11. values,
  12. )
  13. from sqlalchemy.sql import true
  14. from sqlalchemy.pool import NullPool
  15. from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
  16. from sqlalchemy.dialects.postgresql import JSONB, array
  17. from pgvector.sqlalchemy import Vector
  18. from sqlalchemy.ext.mutable import MutableDict
  19. from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
  20. from open_webui.config import PGVECTOR_DB_URL
# Fixed dimensionality of the `vector` column; shorter embeddings are
# zero-padded up to this length (see PgvectorClient.adjust_vector_length),
# longer ones are rejected.
VECTOR_LENGTH = 1536

# Declarative base shared by the ORM models in this module.
Base = declarative_base()
class DocumentChunk(Base):
    """ORM model: one embedded text chunk stored in the "document_chunk" table."""

    __tablename__ = "document_chunk"

    # Chunk identifier supplied by the caller (primary key).
    id = Column(Text, primary_key=True)
    # Embedding; padded to VECTOR_LENGTH before being written.
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    # Logical collection the chunk belongs to.
    collection_name = Column(Text, nullable=False)
    # Raw chunk text.
    text = Column(Text, nullable=True)
    # Arbitrary JSON metadata; MutableDict so in-place dict edits are tracked.
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient:
    """Vector-store client backed by PostgreSQL with the pgvector extension."""

    def __init__(self) -> None:
        # if no pgvector uri, use the existing database connection
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            # Dedicated engine: NullPool opens/closes a connection per checkout,
            # pool_pre_ping validates connections before use.
            engine = create_engine(
                PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
            )
            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available
            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

            # Create the tables if they do not exist
            # Base.metadata.create_all requires a bind (engine or connection)
            # Get the connection from the session
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # Create an index on the vector column if it doesn't exist
            # (approximate nearest-neighbor index for cosine distance).
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            # Plain b-tree index to speed up per-collection filtering.
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )
            self.session.commit()
            print("Initialization complete.")
        except Exception as e:
            # Roll back any partial DDL/DML and surface the failure to the caller.
            self.session.rollback()
            print(f"Error during initialization: {e}")
            raise
  71. def adjust_vector_length(self, vector: List[float]) -> List[float]:
  72. # Adjust vector to have length VECTOR_LENGTH
  73. current_length = len(vector)
  74. if current_length < VECTOR_LENGTH:
  75. # Pad the vector with zeros
  76. vector += [0.0] * (VECTOR_LENGTH - current_length)
  77. elif current_length > VECTOR_LENGTH:
  78. raise Exception(
  79. f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
  80. )
  81. return vector
  82. def insert(self, collection_name: str, items: List[VectorItem]) -> None:
  83. try:
  84. new_items = []
  85. for item in items:
  86. vector = self.adjust_vector_length(item["vector"])
  87. new_chunk = DocumentChunk(
  88. id=item["id"],
  89. vector=vector,
  90. collection_name=collection_name,
  91. text=item["text"],
  92. vmetadata=item["metadata"],
  93. )
  94. new_items.append(new_chunk)
  95. self.session.bulk_save_objects(new_items)
  96. self.session.commit()
  97. print(
  98. f"Inserted {len(new_items)} items into collection '{collection_name}'."
  99. )
  100. except Exception as e:
  101. self.session.rollback()
  102. print(f"Error during insert: {e}")
  103. raise
  104. def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
  105. try:
  106. for item in items:
  107. vector = self.adjust_vector_length(item["vector"])
  108. existing = (
  109. self.session.query(DocumentChunk)
  110. .filter(DocumentChunk.id == item["id"])
  111. .first()
  112. )
  113. if existing:
  114. existing.vector = vector
  115. existing.text = item["text"]
  116. existing.vmetadata = item["metadata"]
  117. existing.collection_name = (
  118. collection_name # Update collection_name if necessary
  119. )
  120. else:
  121. new_chunk = DocumentChunk(
  122. id=item["id"],
  123. vector=vector,
  124. collection_name=collection_name,
  125. text=item["text"],
  126. vmetadata=item["metadata"],
  127. )
  128. self.session.add(new_chunk)
  129. self.session.commit()
  130. print(f"Upserted {len(items)} items into collection '{collection_name}'.")
  131. except Exception as e:
  132. self.session.rollback()
  133. print(f"Error during upsert: {e}")
  134. raise
    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        """Run a cosine-distance nearest-neighbor search for each query vector.

        All query vectors are evaluated in a single SQL statement: they are
        materialized as a VALUES construct and LATERAL-joined against the
        per-vector top-`limit` chunks of the collection. Returns one result
        list per query vector, or None on error / empty input.
        """
        try:
            if not vectors:
                return None
            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                # Cast the Python list to a SQL pgvector literal.
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Create the values for query vectors
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            # Build the lateral subquery for each query vector: chunks of this
            # collection ordered by cosine distance to that vector.
            subq = (
                select(
                    DocumentChunk.id,
                    DocumentChunk.text,
                    DocumentChunk.vmetadata,
                    (
                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
                    ).label("distance"),
                )
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            # One result bucket per query vector, indexed by qid.
            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            if not results:
                # No matches at all: return empty (but correctly shaped) buckets.
                return SearchResult(
                    ids=ids,
                    distances=distances,
                    documents=documents,
                    metadatas=metadatas,
                )

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                distances[qid].append(row.distance)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            print(f"Error during search: {e}")
            return None
  215. def query(
  216. self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
  217. ) -> Optional[GetResult]:
  218. try:
  219. query = self.session.query(DocumentChunk).filter(
  220. DocumentChunk.collection_name == collection_name
  221. )
  222. for key, value in filter.items():
  223. query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
  224. if limit is not None:
  225. query = query.limit(limit)
  226. results = query.all()
  227. if not results:
  228. return None
  229. ids = [[result.id for result in results]]
  230. documents = [[result.text for result in results]]
  231. metadatas = [[result.vmetadata for result in results]]
  232. return GetResult(
  233. ids=ids,
  234. documents=documents,
  235. metadatas=metadatas,
  236. )
  237. except Exception as e:
  238. print(f"Error during query: {e}")
  239. return None
  240. def get(
  241. self, collection_name: str, limit: Optional[int] = None
  242. ) -> Optional[GetResult]:
  243. try:
  244. query = self.session.query(DocumentChunk).filter(
  245. DocumentChunk.collection_name == collection_name
  246. )
  247. if limit is not None:
  248. query = query.limit(limit)
  249. results = query.all()
  250. if not results:
  251. return None
  252. ids = [[result.id for result in results]]
  253. documents = [[result.text for result in results]]
  254. metadatas = [[result.vmetadata for result in results]]
  255. return GetResult(ids=ids, documents=documents, metadatas=metadatas)
  256. except Exception as e:
  257. print(f"Error during get: {e}")
  258. return None
  259. def delete(
  260. self,
  261. collection_name: str,
  262. ids: Optional[List[str]] = None,
  263. filter: Optional[Dict[str, Any]] = None,
  264. ) -> None:
  265. try:
  266. query = self.session.query(DocumentChunk).filter(
  267. DocumentChunk.collection_name == collection_name
  268. )
  269. if ids:
  270. query = query.filter(DocumentChunk.id.in_(ids))
  271. if filter:
  272. for key, value in filter.items():
  273. query = query.filter(
  274. DocumentChunk.vmetadata[key].astext == str(value)
  275. )
  276. deleted = query.delete(synchronize_session=False)
  277. self.session.commit()
  278. print(f"Deleted {deleted} items from collection '{collection_name}'.")
  279. except Exception as e:
  280. self.session.rollback()
  281. print(f"Error during delete: {e}")
  282. raise
  283. def reset(self) -> None:
  284. try:
  285. deleted = self.session.query(DocumentChunk).delete()
  286. self.session.commit()
  287. print(
  288. f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
  289. )
  290. except Exception as e:
  291. self.session.rollback()
  292. print(f"Error during reset: {e}")
  293. raise
  294. def close(self) -> None:
  295. pass
  296. def has_collection(self, collection_name: str) -> bool:
  297. try:
  298. exists = (
  299. self.session.query(DocumentChunk)
  300. .filter(DocumentChunk.collection_name == collection_name)
  301. .first()
  302. is not None
  303. )
  304. return exists
  305. except Exception as e:
  306. print(f"Error checking collection existence: {e}")
  307. return False
  308. def delete_collection(self, collection_name: str) -> None:
  309. self.delete(collection_name)
  310. print(f"Collection '{collection_name}' deleted.")