123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339 |
- from typing import Optional, List, Dict, Any
- from sqlalchemy import (
- cast,
- column,
- Column,
- Integer,
- select,
- text,
- Text,
- values,
- )
- from sqlalchemy.sql import true
- from sqlalchemy.orm import declarative_base, Session
- from sqlalchemy.dialects.postgresql import JSONB, array
- from pgvector.sqlalchemy import Vector
- from sqlalchemy.ext.mutable import MutableDict
- from sqlalchemy.ext.declarative import declarative_base
- from open_webui.apps.webui.internal.db import Session
- from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
# Fixed dimensionality of the stored embedding column; shorter vectors are
# zero-padded up to this length and longer ones are rejected (see
# PgvectorClient.adjust_vector_length).
VECTOR_LENGTH = 1536

# Declarative base for the ORM model(s) mapped onto the pgvector-backed tables.
Base = declarative_base()
class DocumentChunk(Base):
    """ORM row for one embedded text chunk in the shared pgvector table."""

    __tablename__ = "document_chunk"

    # Caller-supplied unique chunk id (text primary key, not autoincrement).
    id = Column(Text, primary_key=True)
    # Embedding padded/validated to VECTOR_LENGTH dimensions before storage.
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    # Logical collection the chunk belongs to; indexed in PgvectorClient.__init__.
    collection_name = Column(Text, nullable=False)
    # Raw chunk text.
    text = Column(Text, nullable=True)
    # Arbitrary JSON metadata; MutableDict so in-place mutations are tracked.
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient:
    """Vector-store client backed by PostgreSQL with the pgvector extension.

    All chunks live in the single ``document_chunk`` table and are scoped
    logically by ``collection_name``.  The class reuses the application-wide
    ``Session`` imported from the internal db module; because that session is
    shared, every error path must roll back so later calls do not hit a
    session stuck in a failed transaction.
    """

    def __init__(self) -> None:
        self.session = Session
        try:
            # The Vector column type requires the pgvector extension to exist
            # before the table is created.
            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

            # Base.metadata.create_all requires a bind (engine or connection);
            # borrow the connection from the session.
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # ANN index for cosine-distance search on the embedding column.
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            # B-tree index for the collection_name filter used by every query.
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )
            self.session.commit()
            print("Initialization complete.")
        except Exception as e:
            self.session.rollback()
            print(f"Error during initialization: {e}")
            raise

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
        """Return ``vector`` zero-padded to exactly ``VECTOR_LENGTH`` floats.

        Raises:
            ValueError: if ``vector`` is longer than ``VECTOR_LENGTH``
                (a subclass of Exception, so existing broad handlers
                still catch it).
        """
        current_length = len(vector)
        if current_length > VECTOR_LENGTH:
            raise ValueError(
                f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
            )
        if current_length < VECTOR_LENGTH:
            # Return a padded copy; the previous `vector += ...` extended the
            # caller's list in place, a surprising side effect.
            return vector + [0.0] * (VECTOR_LENGTH - current_length)
        return vector

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Bulk-insert ``items`` into ``collection_name``.

        Assumes all ids are new; use :meth:`upsert` to update existing rows.
        """
        try:
            new_items = [
                DocumentChunk(
                    id=item["id"],
                    vector=self.adjust_vector_length(item["vector"]),
                    collection_name=collection_name,
                    text=item["text"],
                    vmetadata=item["metadata"],
                )
                for item in items
            ]
            self.session.bulk_save_objects(new_items)
            self.session.commit()
            print(
                f"Inserted {len(new_items)} items into collection '{collection_name}'."
            )
        except Exception as e:
            self.session.rollback()
            print(f"Error during insert: {e}")
            raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert ``items``, updating in place any rows whose id already exists."""
        try:
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                existing = (
                    self.session.query(DocumentChunk)
                    .filter(DocumentChunk.id == item["id"])
                    .first()
                )
                if existing:
                    existing.vector = vector
                    existing.text = item["text"]
                    existing.vmetadata = item["metadata"]
                    # Move the chunk if it was stored under another collection.
                    existing.collection_name = collection_name
                else:
                    self.session.add(
                        DocumentChunk(
                            id=item["id"],
                            vector=vector,
                            collection_name=collection_name,
                            text=item["text"],
                            vmetadata=item["metadata"],
                        )
                    )
            self.session.commit()
            print(f"Upserted {len(items)} items into collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            print(f"Error during upsert: {e}")
            raise

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        """Cosine-distance search for each query vector in ``vectors``.

        Returns per-query ids/distances/documents/metadatas (ascending
        distance, at most ``limit`` rows per query), or ``None`` on error or
        empty input.
        """
        try:
            if not vectors:
                return None
            # Pad every query vector to the fixed column dimensionality.
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Inline VALUES list carrying (qid, q_vector) pairs so all queries
            # run in a single round trip.
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            # LATERAL subquery producing the best matches for each query row.
            subq = (
                select(
                    DocumentChunk.id,
                    DocumentChunk.text,
                    DocumentChunk.vmetadata,
                    (
                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
                    ).label("distance"),
                )
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Join each query vector to its lateral result set.
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            results = self.session.execute(stmt).all()

            # Group the flat rows back into one bucket per query vector; empty
            # result sets yield empty per-query lists.
            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                distances[qid].append(row.distance)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            # Roll back so the shared session is not left in a failed
            # transaction state for subsequent calls.
            self.session.rollback()
            print(f"Error during search: {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch chunks whose JSONB metadata matches every key/value in ``filter``.

        Returns ``None`` when nothing matches or on error.
        """
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            # Compare as text: JSONB ->> key against str(value).
            for key, value in filter.items():
                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
            if limit is not None:
                query = query.limit(limit)
            results = query.all()
            if not results:
                return None
            return GetResult(
                ids=[[result.id for result in results]],
                documents=[[result.text for result in results]],
                metadatas=[[result.vmetadata for result in results]],
            )
        except Exception as e:
            # Keep the shared session usable after a failed statement.
            self.session.rollback()
            print(f"Error during query: {e}")
            return None

    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch up to ``limit`` chunks of a collection (all when ``limit`` is None)."""
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if limit is not None:
                query = query.limit(limit)
            results = query.all()
            if not results:
                return None
            return GetResult(
                ids=[[result.id for result in results]],
                documents=[[result.text for result in results]],
                metadatas=[[result.vmetadata for result in results]],
            )
        except Exception as e:
            # Keep the shared session usable after a failed statement.
            self.session.rollback()
            print(f"Error during get: {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Delete a collection's chunks, optionally narrowed by ids and/or metadata filter."""
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if ids:
                query = query.filter(DocumentChunk.id.in_(ids))
            if filter:
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
            # synchronize_session=False: bulk DELETE without reconciling
            # in-memory objects; cheapest option for this write-only path.
            deleted = query.delete(synchronize_session=False)
            self.session.commit()
            print(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            print(f"Error during delete: {e}")
            raise

    def reset(self) -> None:
        """Delete every row in the document_chunk table (all collections)."""
        try:
            deleted = self.session.query(DocumentChunk).delete()
            self.session.commit()
            print(
                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
            )
        except Exception as e:
            self.session.rollback()
            print(f"Error during reset: {e}")
            raise

    def close(self) -> None:
        """No-op: the session is shared and owned by the application."""
        pass

    def has_collection(self, collection_name: str) -> bool:
        """Return True if at least one chunk exists for ``collection_name``."""
        try:
            return (
                self.session.query(DocumentChunk)
                .filter(DocumentChunk.collection_name == collection_name)
                .first()
                is not None
            )
        except Exception as e:
            # Keep the shared session usable after a failed statement.
            self.session.rollback()
            print(f"Error checking collection existence: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Remove every chunk belonging to ``collection_name``."""
        self.delete(collection_name)
        print(f"Collection '{collection_name}' deleted.")
|