Преглед изворни кода

Merge pull request #3221 from perfectra1n/feature-external-db-reconnect

feat: external db reconnect
Timothy Jaeryang Baek пре 10 месеци
родитељ
комит
1e0453221d

+ 20 - 0
.github/workflows/integration-test.yml

@@ -170,6 +170,26 @@ jobs:
               echo "Server has stopped"
               echo "Server has stopped"
               exit 1
               exit 1
           fi
           fi
+          
+          # Check that service will reconnect to postgres when connection will be closed
+          status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/api/tags)
+          if [[ "$status_code" -ne 200 ]] ; then
+            echo "Server has failed before postgres reconnect check"
+            exit 1
+          fi
+
+          echo "Terminating all connections to postgres..."
+          python -c "import os, psycopg2 as pg2; \
+            conn = pg2.connect(dsn=os.environ['DATABASE_URL'].replace('+pool', '')); \
+            cur = conn.cursor(); \
+            cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
+          
+          status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/api/tags)
+          if [[ "$status_code" -ne 200 ]] ; then
+            echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
+            exit 1
+          fi
+
 
 
 #      - name: Test backend with MySQL
 #      - name: Test backend with MySQL
 #        if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
 #        if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'

+ 21 - 8
backend/apps/webui/internal/db.py

@@ -1,16 +1,16 @@
+import os
+import logging
 import json
 import json
 
 
 from peewee import *
 from peewee import *
 from peewee_migrate import Router
 from peewee_migrate import Router
-from playhouse.db_url import connect
+
+from apps.webui.internal.wrappers import register_connection
 from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
 from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
-import os
-import logging
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])
 log.setLevel(SRC_LOG_LEVELS["DB"])
 
 
-
 class JSONField(TextField):
 class JSONField(TextField):
     def db_value(self, value):
     def db_value(self, value):
         return json.dumps(value)
         return json.dumps(value)
@@ -19,7 +19,6 @@ class JSONField(TextField):
         if value is not None:
         if value is not None:
             return json.loads(value)
             return json.loads(value)
 
 
-
 # Check if the file exists
 # Check if the file exists
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
     # Rename the file
     # Rename the file
@@ -28,12 +27,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
 else:
 else:
     pass
     pass
 
 
-DB = connect(DATABASE_URL)
-log.info(f"Connected to a {DB.__class__.__name__} database.")
+
+# The `register_connection` function encapsulates the logic for setting up 
+# the database connection based on the connection string, while `connect` 
+# is a Peewee-specific method to manage the connection state and avoid errors 
+# when a connection is already open.
+try:
+    DB = register_connection(DATABASE_URL)
+    log.info(f"Connected to a {DB.__class__.__name__} database.")
+except Exception as e:
+    log.error(f"Failed to initialize the database connection: {e}")
+    raise
+
 router = Router(
 router = Router(
     DB,
     DB,
     migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
     migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
     logger=log,
     logger=log,
 )
 )
 router.run()
 router.run()
-DB.connect(reuse_if_open=True)
+try:
+    DB.connect(reuse_if_open=True)
+except OperationalError as e:
+    log.info(f"Failed to connect to database again due to: {e}")
+    pass

+ 72 - 0
backend/apps/webui/internal/wrappers.py

@@ -0,0 +1,72 @@
+from contextvars import ContextVar
+from peewee import *
+from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
+
+import logging
+from playhouse.db_url import connect, parse
+from playhouse.shortcuts import ReconnectMixin
+
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["DB"])
+
+db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
+db_state = ContextVar("db_state", default=db_state_default.copy())
+
+
+class PeeweeConnectionState(object):
+    def __init__(self, **kwargs):
+        super().__setattr__("_state", db_state)
+        super().__init__(**kwargs)
+
+    def __setattr__(self, name, value):
+        self._state.get()[name] = value
+
+    def __getattr__(self, name):
+        value = self._state.get()[name]
+        return value
+
+
+class CustomReconnectMixin(ReconnectMixin):
+    reconnect_errors = (
+        # psycopg2
+        (OperationalError, "termin"),
+        (InterfaceError, "closed"),
+        # peewee
+        (PeeWeeInterfaceError, "closed"),
+    )
+
+
+class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
+    pass
+
+
+def register_connection(db_url):
+    db = connect(db_url)
+    if isinstance(db, PostgresqlDatabase):
+        # Enable autoconnect for SQLite databases, managed by Peewee
+        db.autoconnect = True
+        db.reuse_if_open = True
+        log.info("Connected to PostgreSQL database")
+
+        # Get the connection details
+        connection = parse(db_url)
+
+        # Use our custom database class that supports reconnection
+        db = ReconnectingPostgresqlDatabase(
+            connection["database"],
+            user=connection["user"],
+            password=connection["password"],
+            host=connection["host"],
+            port=connection["port"],
+        )
+        db.connect(reuse_if_open=True)
+    elif isinstance(db, SqliteDatabase):
+        # Enable autoconnect for SQLite databases, managed by Peewee
+        db.autoconnect = True
+        db.reuse_if_open = True
+        log.info("Connected to SQLite database")
+    else:
+        raise ValueError("Unsupported database connection")
+    return db