Browse Source

feat: Add ability to set URI for pgvector

Jason Kidd 6 months ago
parent
commit
319ea8cb7f

+ 17 - 4
backend/open_webui/apps/retrieval/vector/dbs/pgvector.py

@@ -2,6 +2,7 @@ from typing import Optional, List, Dict, Any
 from sqlalchemy import (
     cast,
     column,
+    create_engine,
     Column,
     Integer,
     select,
@@ -10,15 +11,15 @@ from sqlalchemy import (
     values,
 )
 from sqlalchemy.sql import true
+from sqlalchemy.pool import NullPool
 
-from sqlalchemy.orm import declarative_base, Session
+from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
 from sqlalchemy.dialects.postgresql import JSONB, array
 from pgvector.sqlalchemy import Vector
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.ext.declarative import declarative_base
 
-from open_webui.apps.webui.internal.db import Session
 from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import PGVECTOR_DB_URL
 
 VECTOR_LENGTH = 1536
 Base = declarative_base()
@@ -36,7 +37,19 @@ class DocumentChunk(Base):
 
 class PgvectorClient:
     def __init__(self) -> None:
-        self.session = Session
+
+        # if no pgvector uri, use the existing database connection
+        if not PGVECTOR_DB_URL:
+            from open_webui.apps.webui.internal.db import Session
+
+            self.session = Session
+        else:
+            engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
+            SessionLocal = sessionmaker(
+                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
+            )
+            self.session = scoped_session(SessionLocal)
+
         try:
             # Ensure the pgvector extension is available
             self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

+ 5 - 3
backend/open_webui/config.py

@@ -932,9 +932,6 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
 
 VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
 
-if VECTOR_DB == 'pgvector' and not DATABASE_URL.startswith("postgres"):
-    raise ValueError("Pgvector requires using Postgres with vector extension as the primary database.")
-
 # Chroma
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
@@ -968,6 +965,11 @@ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
 OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
 OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
 
+# Pgvector
+PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", None)
+if VECTOR_DB == 'pgvector' and not (DATABASE_URL.startswith("postgres") or PGVECTOR_DB_URL):
+    raise ValueError("Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database.")
+
 ####################################
 # Information Retrieval (RAG)
 ####################################