Ver código fonte

enh: handle peewee migration

Timothy J. Baek 10 meses atrás
pai
commit
1436bb7c61

+ 33 - 3
backend/apps/webui/internal/db.py

@@ -2,6 +2,10 @@ import os
 import logging
 import json
 from contextlib import contextmanager
+
+from peewee_migrate import Router
+from apps.webui.internal.wrappers import register_connection
+
 from typing import Optional, Any
 from typing_extensions import Self
 
@@ -46,6 +50,35 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
 else:
     pass
 
+
+# Workaround to handle the peewee migration
+# This is required to ensure the peewee migration is handled before the alembic migration
+def handle_peewee_migration():
+    try:
+        db = register_connection(DATABASE_URL)
+        migrate_dir = BACKEND_DIR / "apps" / "webui" / "internal" / "migrations"
+        router = Router(db, logger=log, migrate_dir=migrate_dir)
+        router.run()
+        db.close()
+
+        # check if db connection has been closed
+
+    except Exception as e:
+        log.error(f"Failed to initialize the database connection: {e}")
+        raise
+
+    finally:
+        # Properly closing the database connection
+        if db and not db.is_closed():
+            db.close()
+
+        # Assert if db connection has been closed
+        assert db.is_closed(), "Database connection is still open."
+
+
+handle_peewee_migration()
+
+
 SQLALCHEMY_DATABASE_URL = DATABASE_URL
 if "sqlite" in SQLALCHEMY_DATABASE_URL:
     engine = create_engine(
@@ -62,9 +95,6 @@ Base = declarative_base()
 Session = scoped_session(SessionLocal)
 
 
-from contextlib import contextmanager
-
-
 # Dependency
 def get_session():
     db = SessionLocal()

+ 72 - 0
backend/apps/webui/internal/wrappers.py

@@ -0,0 +1,72 @@
+from contextvars import ContextVar
+from peewee import *
+from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
+
+import logging
+from playhouse.db_url import connect, parse
+from playhouse.shortcuts import ReconnectMixin
+
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["DB"])
+
+db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
+db_state = ContextVar("db_state", default=db_state_default.copy())
+
+
+class PeeweeConnectionState(object):
+    def __init__(self, **kwargs):
+        super().__setattr__("_state", db_state)
+        super().__init__(**kwargs)
+
+    def __setattr__(self, name, value):
+        self._state.get()[name] = value
+
+    def __getattr__(self, name):
+        value = self._state.get()[name]
+        return value
+
+
+class CustomReconnectMixin(ReconnectMixin):
+    reconnect_errors = (
+        # psycopg2
+        (OperationalError, "termin"),
+        (InterfaceError, "closed"),
+        # peewee
+        (PeeWeeInterfaceError, "closed"),
+    )
+
+
+class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
+    pass
+
+
+def register_connection(db_url):
+    db = connect(db_url)
+    if isinstance(db, PostgresqlDatabase):
+        # Enable autoconnect for SQLite databases, managed by Peewee
+        db.autoconnect = True
+        db.reuse_if_open = True
+        log.info("Connected to PostgreSQL database")
+
+        # Get the connection details
+        connection = parse(db_url)
+
+        # Use our custom database class that supports reconnection
+        db = ReconnectingPostgresqlDatabase(
+            connection["database"],
+            user=connection["user"],
+            password=connection["password"],
+            host=connection["host"],
+            port=connection["port"],
+        )
+        db.connect(reuse_if_open=True)
+    elif isinstance(db, SqliteDatabase):
+        # Enable autoconnect for SQLite databases, managed by Peewee
+        db.autoconnect = True
+        db.reuse_if_open = True
+        log.info("Connected to SQLite database")
+    else:
+        raise ValueError("Unsupported database connection")
+    return db