Sfoglia il codice sorgente

Reconnect to postgresql & mysql external databases when getting disconnected

Беклемишев Петр Алексеевич 11 mesi fa
parent
commit
dfbc125947

+ 9 - 0
backend/apps/webui/internal/db.py

@@ -7,6 +7,12 @@ from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
 import os
 import logging
 
+from peewee_migrate import Router
+from playhouse.db_url import connect
+
+from apps.webui.internal.wrappers import PeeweeConnectionState, register_peewee_databases
+from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL
+
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])
 
@@ -20,6 +26,8 @@ class JSONField(TextField):
             return json.loads(value)
 
 
+register_peewee_databases()
+
 # Check if the file exists
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
     # Rename the file
@@ -29,6 +37,7 @@ else:
     pass
 
 DB = connect(DATABASE_URL)
+DB._state = PeeweeConnectionState()
 log.info(f"Connected to a {DB.__class__.__name__} database.")
 router = Router(
     DB,

+ 59 - 0
backend/apps/webui/internal/wrappers.py

@@ -0,0 +1,59 @@
+from contextvars import ContextVar
+
+from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError, MySQLDatabase, _ConnectionState
+from playhouse.db_url import register_database
+from playhouse.pool import PooledPostgresqlDatabase, PooledMySQLDatabase
+from playhouse.shortcuts import ReconnectMixin
+from psycopg2 import OperationalError
+from psycopg2.errors import InterfaceError
+
+
+db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
+db_state = ContextVar("db_state", default=db_state_default.copy())
+
+
+class PeeweeConnectionState(_ConnectionState):
+    def __init__(self, **kwargs):
+        super().__setattr__("_state", db_state)
+        super().__init__(**kwargs)
+
+    def __setattr__(self, name, value):
+        self._state.get()[name] = value
+
+    def __getattr__(self, name):
+        return self._state.get()[name]
+
+
+class CustomReconnectMixin(ReconnectMixin):
+    reconnect_errors = (
+        # default ReconnectMixin exceptions (MySQL specific)
+        *ReconnectMixin.reconnect_errors,
+        # psycopg2
+        (OperationalError, 'termin'),
+        (InterfaceError, 'closed'),
+        # peewee
+        (PeeWeeInterfaceError, 'closed'),
+    )
+
+
+class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
+    pass
+
+
+class ReconnectingPooledPostgresqlDatabase(CustomReconnectMixin, PooledPostgresqlDatabase):
+    pass
+
+
+class ReconnectingMySQLDatabase(CustomReconnectMixin, MySQLDatabase):
+    pass
+
+
+class ReconnectingPooledMySQLDatabase(CustomReconnectMixin, PooledMySQLDatabase):
+    pass
+
+
+def register_peewee_databases():
+    register_database(MySQLDatabase, 'mysql')
+    register_database(PooledMySQLDatabase, 'mysql+pool')
+    register_database(ReconnectingPostgresqlDatabase, 'postgres', 'postgresql')
+    register_database(ReconnectingPooledPostgresqlDatabase, 'postgres+pool', 'postgresql+pool')