Browse Source

feat: Initial support for pgvector

Jason Kidd 5 months ago
parent
commit
701f40aedd

+ 2 - 1
.dockerignore

@@ -16,4 +16,5 @@ _old
 uploads
 .ipynb_checkpoints
 **/*.db
-_test
+_test
+backend/data/*

+ 4 - 0
backend/open_webui/apps/retrieval/vector/connector.py

@@ -12,6 +12,10 @@ elif VECTOR_DB == "opensearch":
     from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient
 
     VECTOR_DB_CLIENT = OpenSearchClient()
+elif VECTOR_DB == "pgvector":
+    from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient
+
+    VECTOR_DB_CLIENT = PgvectorClient()
 else:
     from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
 

+ 339 - 0
backend/open_webui/apps/retrieval/vector/dbs/pgvector.py

@@ -0,0 +1,339 @@
+from typing import Optional, List, Dict, Any
+from sqlalchemy import (
+    cast,
+    column,
+    Column,
+    Integer,
+    select,
+    text,
+    Text,
+    values,
+)
+from sqlalchemy.sql import true
+
+from sqlalchemy.orm import declarative_base, Session
+from sqlalchemy.dialects.postgresql import JSONB, array
+from pgvector.sqlalchemy import Vector
+from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.ext.declarative import declarative_base
+
+from open_webui.apps.webui.internal.db import Session
+from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+
+VECTOR_LENGTH = 1536
+Base = declarative_base()
+
+
+class DocumentChunk(Base):
+    __tablename__ = "document_chunk"
+
+    id = Column(Text, primary_key=True)
+    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
+    collection_name = Column(Text, nullable=False)
+    text = Column(Text, nullable=True)
+    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
+
+
+class PgvectorClient:
+    def __init__(self) -> None:
+        self.session = Session
+        try:
+            # Ensure the pgvector extension is available
+            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+
+            # Create the tables if they do not exist
+            # Base.metadata.create_all requires a bind (engine or connection)
+            # Get the connection from the session
+            connection = self.session.connection()
+            Base.metadata.create_all(bind=connection)
+
+            # Create an index on the vector column if it doesn't exist
+            self.session.execute(
+                text(
+                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
+                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
+                )
+            )
+            self.session.execute(
+                text(
+                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
+                    "ON document_chunk (collection_name);"
+                )
+            )
+            self.session.commit()
+            print("Initialization complete.")
+        except Exception as e:
+            self.session.rollback()
+            print(f"Error during initialization: {e}")
+            raise
+
+    def adjust_vector_length(self, vector: List[float]) -> List[float]:
+        # Adjust vector to have length VECTOR_LENGTH
+        current_length = len(vector)
+        if current_length < VECTOR_LENGTH:
+            # Pad the vector with zeros
+            vector += [0.0] * (VECTOR_LENGTH - current_length)
+        elif current_length > VECTOR_LENGTH:
+            raise Exception(
+                f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
+            )
+        return vector
+
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+        try:
+            new_items = []
+            for item in items:
+                vector = self.adjust_vector_length(item["vector"])
+                new_chunk = DocumentChunk(
+                    id=item["id"],
+                    vector=vector,
+                    collection_name=collection_name,
+                    text=item["text"],
+                    vmetadata=item["metadata"],
+                )
+                new_items.append(new_chunk)
+            self.session.bulk_save_objects(new_items)
+            self.session.commit()
+            print(
+                f"Inserted {len(new_items)} items into collection '{collection_name}'."
+            )
+        except Exception as e:
+            self.session.rollback()
+            print(f"Error during insert: {e}")
+            raise
+
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        try:
+            for item in items:
+                vector = self.adjust_vector_length(item["vector"])
+                existing = (
+                    self.session.query(DocumentChunk)
+                    .filter(DocumentChunk.id == item["id"])
+                    .first()
+                )
+                if existing:
+                    existing.vector = vector
+                    existing.text = item["text"]
+                    existing.vmetadata = item["metadata"]
+                    existing.collection_name = (
+                        collection_name  # Update collection_name if necessary
+                    )
+                else:
+                    new_chunk = DocumentChunk(
+                        id=item["id"],
+                        vector=vector,
+                        collection_name=collection_name,
+                        text=item["text"],
+                        vmetadata=item["metadata"],
+                    )
+                    self.session.add(new_chunk)
+            self.session.commit()
+            print(f"Upserted {len(items)} items into collection '{collection_name}'.")
+        except Exception as e:
+            self.session.rollback()
+            print(f"Error during upsert: {e}")
+            raise
+
+    def search(
+        self,
+        collection_name: str,
+        vectors: List[List[float]],
+        limit: Optional[int] = None,
+    ) -> Optional[SearchResult]:
+        try:
+            if not vectors:
+                return None
+
+            # Adjust query vectors to VECTOR_LENGTH
+            vectors = [self.adjust_vector_length(vector) for vector in vectors]
+            num_queries = len(vectors)
+
+            def vector_expr(vector):
+                return cast(array(vector), Vector(VECTOR_LENGTH))
+
+            # Create the values for query vectors
+            qid_col = column("qid", Integer)
+            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
+            query_vectors = (
+                values(qid_col, q_vector_col)
+                .data(
+                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
+                )
+                .alias("query_vectors")
+            )
+
+            # Build the lateral subquery for each query vector
+            subq = (
+                select(
+                    DocumentChunk.id,
+                    DocumentChunk.text,
+                    DocumentChunk.vmetadata,
+                    (
+                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
+                    ).label("distance"),
+                )
+                .where(DocumentChunk.collection_name == collection_name)
+                .order_by(
+                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
+                )
+            )
+            if limit is not None:
+                subq = subq.limit(limit)
+            subq = subq.lateral("result")
+
+            # Build the main query by joining query_vectors and the lateral subquery
+            stmt = (
+                select(
+                    query_vectors.c.qid,
+                    subq.c.id,
+                    subq.c.text,
+                    subq.c.vmetadata,
+                    subq.c.distance,
+                )
+                .select_from(query_vectors)
+                .join(subq, true())
+                .order_by(query_vectors.c.qid, subq.c.distance)
+            )
+
+            result_proxy = self.session.execute(stmt)
+            results = result_proxy.all()
+
+            ids = [[] for _ in range(num_queries)]
+            distances = [[] for _ in range(num_queries)]
+            documents = [[] for _ in range(num_queries)]
+            metadatas = [[] for _ in range(num_queries)]
+
+            if not results:
+                return SearchResult(
+                    ids=ids,
+                    distances=distances,
+                    documents=documents,
+                    metadatas=metadatas,
+                )
+
+            for row in results:
+                qid = int(row.qid)
+                ids[qid].append(row.id)
+                distances[qid].append(row.distance)
+                documents[qid].append(row.text)
+                metadatas[qid].append(row.vmetadata)
+
+            return SearchResult(
+                ids=ids, distances=distances, documents=documents, metadatas=metadatas
+            )
+        except Exception as e:
+            print(f"Error during search: {e}")
+            return None
+
+    def query(
+        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        try:
+            query = self.session.query(DocumentChunk).filter(
+                DocumentChunk.collection_name == collection_name
+            )
+
+            for key, value in filter.items():
+                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
+
+            if limit is not None:
+                query = query.limit(limit)
+
+            results = query.all()
+
+            if not results:
+                return None
+
+            ids = [[result.id for result in results]]
+            documents = [[result.text for result in results]]
+            metadatas = [[result.vmetadata for result in results]]
+
+            return GetResult(
+                ids=ids,
+                documents=documents,
+                metadatas=metadatas,
+            )
+        except Exception as e:
+            print(f"Error during query: {e}")
+            return None
+
+    def get(
+        self, collection_name: str, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        try:
+            query = self.session.query(DocumentChunk).filter(
+                DocumentChunk.collection_name == collection_name
+            )
+            if limit is not None:
+                query = query.limit(limit)
+
+            results = query.all()
+
+            if not results:
+                return None
+
+            ids = [[result.id for result in results]]
+            documents = [[result.text for result in results]]
+            metadatas = [[result.vmetadata for result in results]]
+
+            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
+        except Exception as e:
+            print(f"Error during get: {e}")
+            return None
+
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        try:
+            query = self.session.query(DocumentChunk).filter(
+                DocumentChunk.collection_name == collection_name
+            )
+            if ids:
+                query = query.filter(DocumentChunk.id.in_(ids))
+            if filter:
+                for key, value in filter.items():
+                    query = query.filter(
+                        DocumentChunk.vmetadata[key].astext == str(value)
+                    )
+            deleted = query.delete(synchronize_session=False)
+            self.session.commit()
+            print(f"Deleted {deleted} items from collection '{collection_name}'.")
+        except Exception as e:
+            self.session.rollback()
+            print(f"Error during delete: {e}")
+            raise
+
+    def reset(self) -> None:
+        try:
+            deleted = self.session.query(DocumentChunk).delete()
+            self.session.commit()
+            print(
+                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
+            )
+        except Exception as e:
+            self.session.rollback()
+            print(f"Error during reset: {e}")
+            raise
+
+    def close(self) -> None:
+        pass
+
+    def has_collection(self, collection_name: str) -> bool:
+        try:
+            exists = (
+                self.session.query(DocumentChunk)
+                .filter(DocumentChunk.collection_name == collection_name)
+                .first()
+                is not None
+            )
+            return exists
+        except Exception as e:
+            print(f"Error checking collection existence: {e}")
+            return False
+
+    def delete_collection(self, collection_name: str) -> None:
+        self.delete(collection_name)
+        print(f"Collection '{collection_name}' deleted.")

+ 4 - 0
backend/open_webui/config.py

@@ -20,6 +20,7 @@ from open_webui.env import (
     WEBUI_FAVICON_URL,
     WEBUI_NAME,
     log,
+    DATABASE_URL,
 )
 from pydantic import BaseModel
 from sqlalchemy import JSON, Column, DateTime, Integer, func
@@ -931,6 +932,9 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
 
 VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
 
+if VECTOR_DB == 'pgvector' and not DATABASE_URL.startswith("postgres"):
+    raise ValueError("Pgvector requires using Postgres with vector extension as the primary database.")
+
 # Chroma
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)

+ 1 - 0
backend/requirements.txt

@@ -19,6 +19,7 @@ alembic==1.13.2
 peewee==3.17.6
 peewee-migrate==1.12.2
 psycopg2-binary==2.9.9
+pgvector==0.3.5
 PyMySQL==1.1.1
 bcrypt==4.2.0