abstract_integration_test.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import logging
  2. import os
  3. import time
  4. import docker
  5. import pytest
  6. from docker import DockerClient
  7. from pytest_docker.plugin import get_docker_ip
  8. from fastapi.testclient import TestClient
  9. from sqlalchemy import text, create_engine
  10. log = logging.getLogger(__name__)
  11. def get_fast_api_client():
  12. from main import app
  13. with TestClient(app) as c:
  14. return c
  15. class AbstractIntegrationTest:
  16. BASE_PATH = None
  17. def create_url(self, path):
  18. if self.BASE_PATH is None:
  19. raise Exception("BASE_PATH is not set")
  20. parts = self.BASE_PATH.split("/")
  21. parts = [part.strip() for part in parts if part.strip() != ""]
  22. path_parts = path.split("/")
  23. path_parts = [part.strip() for part in path_parts if part.strip() != ""]
  24. return "/".join(parts + path_parts)
  25. @classmethod
  26. def setup_class(cls):
  27. pass
  28. def setup_method(self):
  29. pass
  30. @classmethod
  31. def teardown_class(cls):
  32. pass
  33. def teardown_method(self):
  34. pass
  35. class AbstractPostgresTest(AbstractIntegrationTest):
  36. DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
  37. docker_client: DockerClient
  38. def get_db(self):
  39. from apps.webui.internal.db import SessionLocal
  40. return SessionLocal()
  41. @classmethod
  42. def _create_db_url(cls, env_vars_postgres: dict) -> str:
  43. host = get_docker_ip()
  44. user = env_vars_postgres["POSTGRES_USER"]
  45. pw = env_vars_postgres["POSTGRES_PASSWORD"]
  46. port = 8081
  47. db = env_vars_postgres["POSTGRES_DB"]
  48. return f"postgresql://{user}:{pw}@{host}:{port}/{db}"
  49. @classmethod
  50. def setup_class(cls):
  51. super().setup_class()
  52. try:
  53. env_vars_postgres = {
  54. "POSTGRES_USER": "user",
  55. "POSTGRES_PASSWORD": "example",
  56. "POSTGRES_DB": "openwebui",
  57. }
  58. cls.docker_client = docker.from_env()
  59. cls.docker_client.containers.run(
  60. "postgres:16.2",
  61. detach=True,
  62. environment=env_vars_postgres,
  63. name=cls.DOCKER_CONTAINER_NAME,
  64. ports={5432: ("0.0.0.0", 8081)},
  65. command="postgres -c log_statement=all",
  66. )
  67. time.sleep(0.5)
  68. database_url = cls._create_db_url(env_vars_postgres)
  69. os.environ["DATABASE_URL"] = database_url
  70. retries = 10
  71. db = None
  72. while retries > 0:
  73. try:
  74. from config import BACKEND_DIR
  75. db = create_engine(database_url, pool_pre_ping=True)
  76. db = db.connect()
  77. log.info("postgres is ready!")
  78. break
  79. except Exception as e:
  80. log.warning(e)
  81. time.sleep(3)
  82. retries -= 1
  83. if db:
  84. # import must be after setting env!
  85. cls.fast_api_client = get_fast_api_client()
  86. db.close()
  87. else:
  88. raise Exception("Could not connect to Postgres")
  89. except Exception as ex:
  90. log.error(ex)
  91. cls.teardown_class()
  92. pytest.fail(f"Could not setup test environment: {ex}")
  93. def _check_db_connection(self):
  94. retries = 10
  95. while retries > 0:
  96. try:
  97. self.db_session.execute(text("SELECT 1"))
  98. self.db_session.commit()
  99. break
  100. except Exception as e:
  101. self.db_session.rollback()
  102. log.warning(e)
  103. time.sleep(3)
  104. retries -= 1
  105. def setup_method(self):
  106. super().setup_method()
  107. self.db_session = self.get_db()
  108. self._check_db_connection()
  109. @classmethod
  110. def teardown_class(cls) -> None:
  111. super().teardown_class()
  112. cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
  113. def teardown_method(self):
  114. # rollback everything not yet committed
  115. self.db_session.commit()
  116. # truncate all tables
  117. tables = [
  118. "auth",
  119. "chat",
  120. "chatidtag",
  121. "document",
  122. "memory",
  123. "model",
  124. "prompt",
  125. "tag",
  126. '"user"',
  127. ]
  128. for table in tables:
  129. self.db_session.execute(text(f"TRUNCATE TABLE {table}"))
  130. self.db_session.commit()