abstract_integration_test.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import logging
  2. import os
  3. import time
  4. import docker
  5. import pytest
  6. from docker import DockerClient
  7. from pytest_docker.plugin import get_docker_ip
  8. from fastapi.testclient import TestClient
  9. from sqlalchemy import text, create_engine
  10. log = logging.getLogger(__name__)
  11. def get_fast_api_client():
  12. from main import app
  13. with TestClient(app) as c:
  14. return c
  15. class AbstractIntegrationTest:
  16. BASE_PATH = None
  17. def create_url(self, path):
  18. if self.BASE_PATH is None:
  19. raise Exception("BASE_PATH is not set")
  20. parts = self.BASE_PATH.split("/")
  21. parts = [part.strip() for part in parts if part.strip() != ""]
  22. path_parts = path.split("/")
  23. path_parts = [part.strip() for part in path_parts if part.strip() != ""]
  24. return "/".join(parts + path_parts)
  25. @classmethod
  26. def setup_class(cls):
  27. pass
  28. def setup_method(self):
  29. pass
  30. @classmethod
  31. def teardown_class(cls):
  32. pass
  33. def teardown_method(self):
  34. pass
  35. class AbstractPostgresTest(AbstractIntegrationTest):
  36. DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
  37. docker_client: DockerClient
  38. @classmethod
  39. def _create_db_url(cls, env_vars_postgres: dict) -> str:
  40. host = get_docker_ip()
  41. user = env_vars_postgres["POSTGRES_USER"]
  42. pw = env_vars_postgres["POSTGRES_PASSWORD"]
  43. port = 8081
  44. db = env_vars_postgres["POSTGRES_DB"]
  45. return f"postgresql://{user}:{pw}@{host}:{port}/{db}"
  46. @classmethod
  47. def setup_class(cls):
  48. super().setup_class()
  49. try:
  50. env_vars_postgres = {
  51. "POSTGRES_USER": "user",
  52. "POSTGRES_PASSWORD": "example",
  53. "POSTGRES_DB": "openwebui",
  54. }
  55. cls.docker_client = docker.from_env()
  56. cls.docker_client.containers.run(
  57. "postgres:16.2",
  58. detach=True,
  59. environment=env_vars_postgres,
  60. name=cls.DOCKER_CONTAINER_NAME,
  61. ports={5432: ("0.0.0.0", 8081)},
  62. command="postgres -c log_statement=all",
  63. )
  64. time.sleep(0.5)
  65. database_url = cls._create_db_url(env_vars_postgres)
  66. os.environ["DATABASE_URL"] = database_url
  67. retries = 10
  68. db = None
  69. while retries > 0:
  70. try:
  71. from config import BACKEND_DIR
  72. db = create_engine(database_url, pool_pre_ping=True)
  73. db = db.connect()
  74. log.info("postgres is ready!")
  75. break
  76. except Exception as e:
  77. log.warning(e)
  78. time.sleep(3)
  79. retries -= 1
  80. if db:
  81. # import must be after setting env!
  82. cls.fast_api_client = get_fast_api_client()
  83. db.close()
  84. else:
  85. raise Exception("Could not connect to Postgres")
  86. except Exception as ex:
  87. log.error(ex)
  88. cls.teardown_class()
  89. pytest.fail(f"Could not setup test environment: {ex}")
  90. def _check_db_connection(self):
  91. from apps.webui.internal.db import Session
  92. retries = 10
  93. while retries > 0:
  94. try:
  95. Session.execute(text("SELECT 1"))
  96. Session.commit()
  97. break
  98. except Exception as e:
  99. Session.rollback()
  100. log.warning(e)
  101. time.sleep(3)
  102. retries -= 1
  103. def setup_method(self):
  104. super().setup_method()
  105. self._check_db_connection()
  106. @classmethod
  107. def teardown_class(cls) -> None:
  108. super().teardown_class()
  109. cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
  110. def teardown_method(self):
  111. from apps.webui.internal.db import Session
  112. # rollback everything not yet committed
  113. Session.commit()
  114. # truncate all tables
  115. tables = [
  116. "auth",
  117. "chat",
  118. "chatidtag",
  119. "document",
  120. "memory",
  121. "model",
  122. "prompt",
  123. "tag",
  124. '"user"',
  125. ]
  126. for table in tables:
  127. Session.execute(text(f"TRUNCATE TABLE {table}"))
  128. Session.commit()