Procházet zdrojové kódy

Merge pull request #8298 from jk-f5/feat/pg_vector_size

feat: Allow setting the initial vector length on pgvector document_chunk table
Timothy Jaeryang Baek před 4 měsíci
rodič
revize
0e805e7dc4

+ 3 - 0
backend/open_webui/config.py

@@ -1211,6 +1211,9 @@ if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
     raise ValueError(
     raise ValueError(
         "Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database."
         "Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database."
     )
     )
+PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
+    os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
+)
 
 
 ####################################
 ####################################
 # Information Retrieval (RAG)
 # Information Retrieval (RAG)

+ 38 - 2
backend/open_webui/retrieval/vector/dbs/pgvector.py

@@ -5,6 +5,7 @@ from sqlalchemy import (
     create_engine,
     create_engine,
     Column,
     Column,
     Integer,
     Integer,
+    MetaData,
     select,
     select,
     text,
     text,
     Text,
     Text,
@@ -19,9 +20,9 @@ from pgvector.sqlalchemy import Vector
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.ext.mutable import MutableDict
 
 
 from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
-from open_webui.config import PGVECTOR_DB_URL
+from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
 
 
-VECTOR_LENGTH = 1536
+VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
 Base = declarative_base()
 Base = declarative_base()
 
 
 
 
@@ -56,6 +57,9 @@ class PgvectorClient:
             # Ensure the pgvector extension is available
             # Ensure the pgvector extension is available
             self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
             self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
 
 
+            # Check vector length consistency
+            self.check_vector_length()
+
             # Create the tables if they do not exist
             # Create the tables if they do not exist
             # Base.metadata.create_all requires a bind (engine or connection)
             # Base.metadata.create_all requires a bind (engine or connection)
             # Get the connection from the session
             # Get the connection from the session
@@ -82,6 +86,38 @@ class PgvectorClient:
             print(f"Error during initialization: {e}")
             print(f"Error during initialization: {e}")
             raise
             raise
 
 
+    def check_vector_length(self) -> None:
+        """
+        Verify that VECTOR_LENGTH matches the dimension of the existing
+        'vector' column of the 'document_chunk' table, if that table exists.
+        A missing table is accepted so that a fresh database can be set up.
+
+        Raises:
+            Exception: the column is missing, is not a Vector, or its
+                dimension differs from VECTOR_LENGTH.
+        """
+        metadata = MetaData()
+        # A callable `only` filters table names; a list would make reflect()
+        # raise InvalidRequestError when the table is absent (fresh database).
+        metadata.reflect(
+            bind=self.session.bind, only=lambda name, _: name == "document_chunk"
+        )
+        document_chunk_table = metadata.tables.get("document_chunk")
+        if document_chunk_table is None:
+            return  # Table does not exist yet; no action needed
+        if "vector" not in document_chunk_table.columns:
+            raise Exception(
+                "The 'vector' column does not exist in the 'document_chunk' table."
+            )
+        vector_type = document_chunk_table.columns["vector"].type
+        if not isinstance(vector_type, Vector):
+            raise Exception("The 'vector' column exists but is not of type 'Vector'.")
+        if vector_type.dim != VECTOR_LENGTH:
+            raise Exception(
+                f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {vector_type.dim}. "
+                "Cannot change vector size after initialization without migrating the data."
+            )
+
     def adjust_vector_length(self, vector: List[float]) -> List[float]:
     def adjust_vector_length(self, vector: List[float]) -> List[float]:
         # Adjust vector to have length VECTOR_LENGTH
         # Adjust vector to have length VECTOR_LENGTH
         current_length = len(vector)
         current_length = len(vector)