abstract_integration_test.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import logging
  2. import os
  3. import time
  4. import docker
  5. import pytest
  6. from docker import DockerClient
  7. from pytest_docker.plugin import get_docker_ip
  8. from fastapi.testclient import TestClient
  9. from sqlalchemy import text, create_engine
  # Module-level logger used by the test base classes in this file.
  log = logging.getLogger(name=__name__)
  def get_fast_api_client():
      """Build a FastAPI ``TestClient`` for the application object in ``main``.

      ``main`` is imported lazily, inside the function: callers note that this
      import must happen only after the environment (e.g. ``DATABASE_URL``) is set.

      Returns:
          The ``TestClient`` instance, after its context has been entered and exited.
      """
      from main import app

      # NOTE(review): per Starlette's TestClient docs, entering/leaving the context
      # runs the app's startup/shutdown handlers, so the returned client is used
      # after shutdown has already run — confirm this is intended.
      with TestClient(app) as client:
          pass
      return client
  class AbstractIntegrationTest:
      """Base class for integration tests: relative-URL building plus xunit-style hooks.

      Subclasses set ``BASE_PATH`` (the route prefix) and may override the
      ``setup_class`` / ``setup_method`` / ``teardown_class`` / ``teardown_method``
      hooks, which are no-ops here.
      """

      # Route prefix prepended by create_url(); subclasses must set it.
      BASE_PATH = None

      def create_url(self, path="", query_params=None):
          """Join ``BASE_PATH`` and *path* into a relative URL with an encoded query string.

          Path segments are split on ``/`` and stripped; empty or whitespace-only
          segments are dropped, so leading, trailing and doubled slashes collapse and
          the result has no leading slash.

          Args:
              path: Extra path appended after ``BASE_PATH``.
              query_params: Optional mapping of parameter names to values. Values are
                  converted with ``str()`` and percent-encoded.

          Returns:
              The joined path, followed by ``?key=value&...`` when *query_params* is
              non-empty.

          Raises:
              Exception: if ``BASE_PATH`` has not been set.
          """
          from urllib.parse import urlencode

          if self.BASE_PATH is None:
              raise Exception("BASE_PATH is not set")

          raw_segments = self.BASE_PATH.split("/") + path.split("/")
          segments = [segment.strip() for segment in raw_segments if segment.strip()]
          url = "/".join(segments)

          if query_params:
              # urlencode escapes reserved characters (&, =, +, spaces, ...) so a value
              # cannot break the query string into extra or malformed parameters.
              url = f"{url}?{urlencode(query_params)}"
          return url

      @classmethod
      def setup_class(cls):
          """Hook run once before the tests of the class; no-op by default."""
          pass

      def setup_method(self):
          """Hook run before each test method; no-op by default."""
          pass

      @classmethod
      def teardown_class(cls):
          """Hook run once after the tests of the class; no-op by default."""
          pass

      def teardown_method(self):
          """Hook run after each test method; no-op by default."""
          pass
  class AbstractPostgresTest(AbstractIntegrationTest):
      """Integration-test base backed by a throwaway Postgres Docker container.

      ``setup_class`` starts a ``postgres:16.2`` container, points ``DATABASE_URL``
      at it, waits until it accepts connections and builds ``cls.fast_api_client``.
      ``teardown_method`` empties the application tables after each test, and
      ``teardown_class`` force-removes the container.
      """

      DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
      # Host port the container's 5432 is published on; also used in DATABASE_URL.
      DOCKER_HOST_PORT = 8081
      docker_client: DockerClient

      @classmethod
      def _create_db_url(cls, env_vars_postgres: dict) -> str:
          """Build the SQLAlchemy URL for the container from its POSTGRES_* env vars."""
          host = get_docker_ip()
          user = env_vars_postgres["POSTGRES_USER"]
          pw = env_vars_postgres["POSTGRES_PASSWORD"]
          port = cls.DOCKER_HOST_PORT
          db = env_vars_postgres["POSTGRES_DB"]
          return f"postgresql://{user}:{pw}@{host}:{port}/{db}"

      @classmethod
      def _wait_for_db(cls, database_url: str, retries: int = 10, delay: float = 3) -> bool:
          """Try to open a connection to *database_url* up to *retries* times.

          Each failed attempt is logged and followed by a *delay*-second sleep.

          Returns:
              True on the first successful connection, False if every attempt fails.
          """
          for _ in range(retries):
              engine = None
              try:
                  # Imported only after DATABASE_URL is set; the name is unused here —
                  # presumably imported for its side effects (TODO confirm).
                  from open_webui.config import OPEN_WEBUI_DIR  # noqa: F401

                  engine = create_engine(database_url, pool_pre_ping=True)
                  with engine.connect():
                      pass
                  log.info("postgres is ready!")
                  return True
              except Exception as e:
                  log.warning(e)
                  time.sleep(delay)
              finally:
                  # Release the probe's connection pool so it holds no connections open.
                  if engine is not None:
                      engine.dispose()
          return False

      @classmethod
      def setup_class(cls):
          """Start the Postgres container, wait for it and create the FastAPI test client.

          On any failure the container is removed (best effort) and the test class
          is failed via ``pytest.fail``.
          """
          super().setup_class()
          try:
              env_vars_postgres = {
                  "POSTGRES_USER": "user",
                  "POSTGRES_PASSWORD": "example",
                  "POSTGRES_DB": "openwebui",
              }
              cls.docker_client = docker.from_env()
              cls.docker_client.containers.run(
                  "postgres:16.2",
                  detach=True,
                  environment=env_vars_postgres,
                  name=cls.DOCKER_CONTAINER_NAME,
                  ports={5432: ("0.0.0.0", cls.DOCKER_HOST_PORT)},
                  command="postgres -c log_statement=all",
              )
              time.sleep(0.5)
              database_url = cls._create_db_url(env_vars_postgres)
              os.environ["DATABASE_URL"] = database_url
              if not cls._wait_for_db(database_url):
                  raise Exception("Could not connect to Postgres")
              # import must be after setting env!
              cls.fast_api_client = get_fast_api_client()
          except Exception as ex:
              log.exception(ex)
              # Cleanup must not mask the original setup error or skip pytest.fail.
              try:
                  cls.teardown_class()
              except Exception as cleanup_error:
                  log.warning("Cleanup after failed setup also failed: %s", cleanup_error)
              pytest.fail(f"Could not setup test environment: {ex}")

      def _check_db_connection(self):
          """Run ``SELECT 1`` through the app Session, retrying up to 10 times, 3 s apart.

          Failed attempts are rolled back and logged. If every attempt fails, an
          error is logged and the method returns without raising.
          """
          from open_webui.internal.db import Session

          for _ in range(10):
              try:
                  Session.execute(text("SELECT 1"))
                  Session.commit()
                  return
              except Exception as e:
                  Session.rollback()
                  log.warning(e)
                  time.sleep(3)
          log.error("Database connection check failed after all retries")

      def setup_method(self):
          """Before each test, verify the database is reachable."""
          super().setup_method()
          self._check_db_connection()

      @classmethod
      def teardown_class(cls) -> None:
          """Force-remove the Postgres container, if one was started."""
          from docker.errors import NotFound

          super().teardown_class()
          # setup_class also calls this on failure, possibly before docker_client is set.
          docker_client = getattr(cls, "docker_client", None)
          if docker_client is None:
              return
          try:
              docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
          except NotFound:
              log.warning(
                  "Container %s not found; nothing to remove", cls.DOCKER_CONTAINER_NAME
              )

      def teardown_method(self):
          """After each test, discard uncommitted work and truncate the application tables."""
          from open_webui.internal.db import Session

          # rollback everything not yet committed (commit would fail on a session
          # left in an error state by the test)
          Session.rollback()
          # truncate all tables
          tables = [
              "auth",
              "chat",
              "chatidtag",
              "document",
              "memory",
              "model",
              "prompt",
              "tag",
              '"user"',  # quoted: user is a reserved word in Postgres
          ]
          for table in tables:
              Session.execute(text(f"TRUNCATE TABLE {table}"))
          Session.commit()