123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- import logging
- import os
- import time
- import docker
- import pytest
- from docker import DockerClient
- from pytest_docker.plugin import get_docker_ip
- from fastapi.testclient import TestClient
- from sqlalchemy import text, create_engine
# Logger for this module's test helpers.
log = logging.getLogger(name=__name__)
def get_fast_api_client():
    """Import the FastAPI ``app`` from ``main`` and return a ``TestClient`` for it.

    The import is deferred to call time on purpose: the caller must set
    ``DATABASE_URL`` in the environment before ``main`` is imported.

    NOTE(review): the client is handed back only after its ``with`` block has
    exited. If, as in Starlette, leaving that block runs the app's shutdown
    handlers, they have already run when callers use the client. Confirm
    this is intended.
    """
    from main import app

    with TestClient(app) as started_client:
        client = started_client
    return client
class AbstractIntegrationTest:
    """Base class for integration tests that call endpoints under ``BASE_PATH``.

    Subclasses set ``BASE_PATH`` and build request paths with ``create_url``.
    The pytest ``setup_*``/``teardown_*`` hooks are no-ops here, so subclasses
    can extend them and call ``super()`` unconditionally.
    """

    # Path prefix of the endpoints under test, e.g. "/api/v1/users".
    BASE_PATH = None

    def create_url(self, path: str = "", query_params=None) -> str:
        """Join ``BASE_PATH`` and ``path`` and append a URL-encoded query string.

        Empty segments and surrounding whitespace are dropped from both paths,
        so ``"/a//b/"`` and ``"a/b"`` give the same result. The returned path
        has no leading slash.

        Args:
            path: Path relative to ``BASE_PATH``.
            query_params: Optional mapping of query parameters. Values are
                stringified and URL-encoded via ``urllib.parse.urlencode``.

        Returns:
            The joined path, followed by ``?<query>`` when ``query_params`` is
            non-empty.

        Raises:
            Exception: If ``BASE_PATH`` is not set.
        """
        from urllib.parse import urlencode

        if self.BASE_PATH is None:
            raise Exception("BASE_PATH is not set")

        def _segments(raw):
            # Split on "/" and drop blank or whitespace-only segments.
            return [seg.strip() for seg in raw.split("/") if seg.strip() != ""]

        url = "/".join(_segments(self.BASE_PATH) + _segments(path))
        if query_params:
            # Encode so values containing "&", "=", "+", "#" or spaces cannot
            # break or be misparsed in the query string.
            url += "?" + urlencode(query_params)
        return url

    @classmethod
    def setup_class(cls):
        """Per-class setup hook (no-op; override in subclasses)."""

    def setup_method(self):
        """Per-test setup hook (no-op; override in subclasses)."""

    @classmethod
    def teardown_class(cls):
        """Per-class teardown hook (no-op; override in subclasses)."""

    def teardown_method(self):
        """Per-test teardown hook (no-op; override in subclasses)."""
class AbstractPostgresTest(AbstractIntegrationTest):
    """Integration-test base that runs each test class against a throwaway Postgres.

    ``setup_class`` does four things:

    - starts a ``postgres:16.2`` Docker container;
    - points ``DATABASE_URL`` at it;
    - waits until it accepts connections;
    - stores a FastAPI test client on ``cls.fast_api_client``.

    ``teardown_class`` force-removes the container. Each test first checks the
    app's DB session, and afterwards ``teardown_method`` truncates the
    application tables.
    """

    DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
    # Host port the container's 5432 is published on. It is used both when
    # starting the container and when building the connection URL.
    DOCKER_HOST_PORT = 8081
    docker_client: DockerClient

    @classmethod
    def _create_db_url(cls, env_vars_postgres: dict) -> str:
        """Build the SQLAlchemy URL for the test container.

        Args:
            env_vars_postgres: Must contain ``POSTGRES_USER``,
                ``POSTGRES_PASSWORD`` and ``POSTGRES_DB``, the same values
                passed to the container.

        Returns:
            A ``postgresql://user:pw@host:port/db`` URL.
        """
        host = get_docker_ip()
        user = env_vars_postgres["POSTGRES_USER"]
        pw = env_vars_postgres["POSTGRES_PASSWORD"]
        port = cls.DOCKER_HOST_PORT
        db = env_vars_postgres["POSTGRES_DB"]
        return f"postgresql://{user}:{pw}@{host}:{port}/{db}"

    @classmethod
    def _wait_for_postgres(cls, database_url: str, retries: int = 10, delay: float = 3) -> bool:
        """Try to open a connection until one succeeds.

        Every attempt uses a fresh engine, which is disposed afterwards so no
        pool is leaked.

        Returns:
            True once a connection was opened, False after ``retries`` failures.
        """
        for attempt in range(1, retries + 1):
            engine = None
            try:
                # NOTE(review): imported only for its side effects (the name is
                # unused); kept inside the retry loop as before.
                from open_webui.config import OPEN_WEBUI_DIR  # noqa: F401

                engine = create_engine(database_url, pool_pre_ping=True)
                with engine.connect():
                    pass
                log.info("postgres is ready!")
                return True
            except Exception as e:
                log.warning(e)
                if attempt < retries:
                    time.sleep(delay)
            finally:
                if engine is not None:
                    engine.dispose()
        return False

    @classmethod
    def setup_class(cls):
        """Start the Postgres container and build ``cls.fast_api_client``.

        On any failure the container is cleaned up and the test class is
        failed through ``pytest.fail``.
        """
        super().setup_class()
        try:
            env_vars_postgres = {
                "POSTGRES_USER": "user",
                "POSTGRES_PASSWORD": "example",
                "POSTGRES_DB": "openwebui",
            }
            cls.docker_client = docker.from_env()
            cls.docker_client.containers.run(
                "postgres:16.2",
                detach=True,
                environment=env_vars_postgres,
                name=cls.DOCKER_CONTAINER_NAME,
                ports={5432: ("0.0.0.0", cls.DOCKER_HOST_PORT)},
                command="postgres -c log_statement=all",
            )
            # Give the container a moment before the first connection attempt.
            time.sleep(0.5)

            database_url = cls._create_db_url(env_vars_postgres)
            os.environ["DATABASE_URL"] = database_url

            # Connection success is tracked explicitly. Previously a created
            # (but never connected) Engine counted as "ready".
            if not cls._wait_for_postgres(database_url):
                raise Exception("Could not connect to Postgres")

            # import must be after setting env!
            cls.fast_api_client = get_fast_api_client()
        except Exception as ex:
            log.error(ex)
            cls.teardown_class()
            pytest.fail(f"Could not setup test environment: {ex}")

    def _check_db_connection(self):
        """Make sure the app's shared ``Session`` can reach the database.

        Retries ``SELECT 1`` up to 10 times, 3 s apart, rolling back after each
        failure. This is best-effort: if every attempt fails, it logs an error
        and returns rather than failing here.
        """
        from open_webui.internal.db import Session

        retries = 10
        for attempt in range(1, retries + 1):
            try:
                Session.execute(text("SELECT 1"))
                Session.commit()
                return
            except Exception as e:
                Session.rollback()
                log.warning(e)
                if attempt < retries:
                    time.sleep(3)
        log.error("Database still unreachable after %d attempts", retries)

    def setup_method(self):
        """Verify the DB session before each test."""
        super().setup_method()
        self._check_db_connection()

    @classmethod
    def teardown_class(cls) -> None:
        """Force-remove the Postgres container.

        This may be called from a failed ``setup_class``, before the Docker
        client or the container exists. Both cases are tolerated so that the
        original setup error is the one reported.
        """
        super().teardown_class()
        docker_client = getattr(cls, "docker_client", None)
        if docker_client is None:
            return
        try:
            docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
        except docker.errors.NotFound:
            log.warning(
                "container %s not found during teardown", cls.DOCKER_CONTAINER_NAME
            )

    def teardown_method(self):
        """Discard pending session state and empty the application tables."""
        super().teardown_method()
        from open_webui.internal.db import Session

        # rollback everything not yet committed (a failed transaction would
        # make commit() raise here)
        Session.rollback()
        # truncate all tables; "user" is quoted because it is a reserved word
        tables = [
            "auth",
            "chat",
            "chatidtag",
            "document",
            "memory",
            "model",
            "prompt",
            "tag",
            '"user"',
        ]
        for table in tables:
            Session.execute(text(f"TRUNCATE TABLE {table}"))
        Session.commit()
|