|
@@ -2,6 +2,7 @@ from typing import Optional, List, Dict, Any
|
|
from sqlalchemy import (
|
|
from sqlalchemy import (
|
|
cast,
|
|
cast,
|
|
column,
|
|
column,
|
|
|
|
+ create_engine,
|
|
Column,
|
|
Column,
|
|
Integer,
|
|
Integer,
|
|
select,
|
|
select,
|
|
@@ -10,15 +11,15 @@ from sqlalchemy import (
|
|
values,
|
|
values,
|
|
)
|
|
)
|
|
from sqlalchemy.sql import true
|
|
from sqlalchemy.sql import true
|
|
|
|
+from sqlalchemy.pool import NullPool
|
|
|
|
|
|
-from sqlalchemy.orm import declarative_base, Session
|
|
|
|
|
|
+from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
|
from sqlalchemy.dialects.postgresql import JSONB, array
|
|
from sqlalchemy.dialects.postgresql import JSONB, array
|
|
from pgvector.sqlalchemy import Vector
|
|
from pgvector.sqlalchemy import Vector
|
|
from sqlalchemy.ext.mutable import MutableDict
|
|
from sqlalchemy.ext.mutable import MutableDict
|
|
-from sqlalchemy.ext.declarative import declarative_base
|
|
|
|
|
|
|
|
-from open_webui.apps.webui.internal.db import Session
|
|
|
|
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
|
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
|
|
|
+from open_webui.config import PGVECTOR_DB_URL
|
|
|
|
|
|
VECTOR_LENGTH = 1536
|
|
VECTOR_LENGTH = 1536
|
|
Base = declarative_base()
|
|
Base = declarative_base()
|
|
@@ -36,7 +37,19 @@ class DocumentChunk(Base):
|
|
|
|
|
|
class PgvectorClient:
|
|
class PgvectorClient:
|
|
def __init__(self) -> None:
|
|
def __init__(self) -> None:
|
|
- self.session = Session
|
|
|
|
|
|
+
|
|
|
|
+ # if no pgvector uri, use the existing database connection
|
|
|
|
+ if not PGVECTOR_DB_URL:
|
|
|
|
+ from open_webui.apps.webui.internal.db import Session
|
|
|
|
+
|
|
|
|
+ self.session = Session
|
|
|
|
+ else:
|
|
|
|
+ engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
|
|
|
|
+ SessionLocal = sessionmaker(
|
|
|
|
+ autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
|
|
|
+ )
|
|
|
|
+ self.session = scoped_session(SessionLocal)
|
|
|
|
+
|
|
try:
|
|
try:
|
|
# Ensure the pgvector extension is available
|
|
# Ensure the pgvector extension is available
|
|
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
|
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|