|
@@ -5,6 +5,7 @@ from sqlalchemy import (
|
|
create_engine,
|
|
create_engine,
|
|
Column,
|
|
Column,
|
|
Integer,
|
|
Integer,
|
|
|
|
+ MetaData,
|
|
select,
|
|
select,
|
|
text,
|
|
text,
|
|
Text,
|
|
Text,
|
|
@@ -19,9 +20,9 @@ from pgvector.sqlalchemy import Vector
|
|
from sqlalchemy.ext.mutable import MutableDict
|
|
from sqlalchemy.ext.mutable import MutableDict
|
|
|
|
|
|
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
|
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
|
-from open_webui.config import PGVECTOR_DB_URL
|
|
|
|
|
|
+from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
|
|
|
|
|
-VECTOR_LENGTH = 1536
|
|
|
|
|
|
+VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
|
Base = declarative_base()
|
|
Base = declarative_base()
|
|
|
|
|
|
|
|
|
|
@@ -56,6 +57,9 @@ class PgvectorClient:
|
|
# Ensure the pgvector extension is available
|
|
# Ensure the pgvector extension is available
|
|
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
|
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
|
|
|
|
|
|
|
+ # Check vector length consistency
|
|
|
|
+ self.check_vector_length()
|
|
|
|
+
|
|
# Create the tables if they do not exist
|
|
# Create the tables if they do not exist
|
|
# Base.metadata.create_all requires a bind (engine or connection)
|
|
# Base.metadata.create_all requires a bind (engine or connection)
|
|
# Get the connection from the session
|
|
# Get the connection from the session
|
|
@@ -82,6 +86,38 @@ class PgvectorClient:
|
|
print(f"Error during initialization: {e}")
|
|
print(f"Error during initialization: {e}")
|
|
raise
|
|
raise
|
|
|
|
|
|
|
|
+ def check_vector_length(self) -> None:
|
|
|
|
+ """
|
|
|
|
+ Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
|
|
|
|
+ Raises an exception if there is a mismatch.
|
|
|
|
+ """
|
|
|
|
+ metadata = MetaData()
|
|
|
|
+ metadata.reflect(bind=self.session.bind, only=["document_chunk"])
|
|
|
|
+
|
|
|
|
+ if "document_chunk" in metadata.tables:
|
|
|
|
+ document_chunk_table = metadata.tables["document_chunk"]
|
|
|
|
+ if "vector" in document_chunk_table.columns:
|
|
|
|
+ vector_column = document_chunk_table.columns["vector"]
|
|
|
|
+ vector_type = vector_column.type
|
|
|
|
+ if isinstance(vector_type, Vector):
|
|
|
|
+ db_vector_length = vector_type.dim
|
|
|
|
+ if db_vector_length != VECTOR_LENGTH:
|
|
|
|
+ raise Exception(
|
|
|
|
+ f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
|
|
|
|
+ "Cannot change vector size after initialization without migrating the data."
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ raise Exception(
|
|
|
|
+ "The 'vector' column exists but is not of type 'Vector'."
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ raise Exception(
|
|
|
|
+ "The 'vector' column does not exist in the 'document_chunk' table."
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ # Table does not exist yet; no action needed
|
|
|
|
+ pass
|
|
|
|
+
|
|
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
|
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
|
# Adjust vector to have length VECTOR_LENGTH
|
|
# Adjust vector to have length VECTOR_LENGTH
|
|
current_length = len(vector)
|
|
current_length = len(vector)
|