wrappers.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
from contextvars import ContextVar

from peewee import (
    InterfaceError as PeeWeeInterfaceError,
    MySQLDatabase,
    OperationalError as PeeWeeOperationalError,
    PostgresqlDatabase,
    _ConnectionState,
)
from playhouse.db_url import register_database
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from playhouse.shortcuts import ReconnectMixin
from psycopg2 import OperationalError
from psycopg2.errors import InterfaceError
  8. db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
  9. db_state = ContextVar("db_state", default=db_state_default.copy())
  10. class PeeweeConnectionState(_ConnectionState):
  11. def __init__(self, **kwargs):
  12. super().__setattr__("_state", db_state)
  13. super().__init__(**kwargs)
  14. def __setattr__(self, name, value):
  15. self._state.get()[name] = value
  16. def __getattr__(self, name):
  17. return self._state.get()[name]
  18. class CustomReconnectMixin(ReconnectMixin):
  19. reconnect_errors = (
  20. # default ReconnectMixin exceptions (MySQL specific)
  21. *ReconnectMixin.reconnect_errors,
  22. # psycopg2
  23. (OperationalError, 'termin'),
  24. (InterfaceError, 'closed'),
  25. # peewee
  26. (PeeWeeInterfaceError, 'closed'),
  27. )
  28. class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
  29. pass
  30. class ReconnectingPooledPostgresqlDatabase(CustomReconnectMixin, PooledPostgresqlDatabase):
  31. pass
  32. class ReconnectingMySQLDatabase(CustomReconnectMixin, MySQLDatabase):
  33. pass
  34. class ReconnectingPooledMySQLDatabase(CustomReconnectMixin, PooledMySQLDatabase):
  35. pass
  36. def register_peewee_databases():
  37. register_database(MySQLDatabase, 'mysql')
  38. register_database(PooledMySQLDatabase, 'mysql+pool')
  39. register_database(ReconnectingPostgresqlDatabase, 'postgres', 'postgresql')
  40. register_database(ReconnectingPooledPostgresqlDatabase, 'postgres+pool', 'postgresql+pool')