import io import boto3 import pytest from botocore.exceptions import ClientError from moto import mock_aws from open_webui.storage import provider def mock_upload_dir(monkeypatch, tmp_path): """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory.""" directory = tmp_path / "uploads" directory.mkdir() monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory)) return directory def test_imports(): provider.StorageProvider provider.LocalStorageProvider provider.S3StorageProvider provider.Storage def test_get_storage_provider(): Storage = provider.get_storage_provider("local") assert isinstance(Storage, provider.LocalStorageProvider) Storage = provider.get_storage_provider("s3") assert isinstance(Storage, provider.S3StorageProvider) with pytest.raises(RuntimeError): provider.get_storage_provider("invalid") def test_class_instantiation(): with pytest.raises(TypeError): provider.StorageProvider() with pytest.raises(TypeError): class Test(provider.StorageProvider): pass Test() provider.LocalStorageProvider() provider.S3StorageProvider() class TestLocalStorageProvider: Storage = provider.LocalStorageProvider() file_content = b"test content" file_bytesio = io.BytesIO(file_content) filename = "test.txt" filename_extra = "test_exyta.txt" file_bytesio_empty = io.BytesIO() def test_upload_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename) assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content assert contents == self.file_content assert file_path == str(upload_dir / self.filename) with pytest.raises(ValueError): self.Storage.upload_file(self.file_bytesio_empty, self.filename) def test_get_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) file_path = str(upload_dir / self.filename) file_path_return = self.Storage.get_file(file_path) assert file_path == file_path_return def test_delete_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) (upload_dir / self.filename).write_bytes(self.file_content) assert (upload_dir / self.filename).exists() file_path = str(upload_dir / self.filename) self.Storage.delete_file(file_path) assert not (upload_dir / self.filename).exists() def test_delete_all_files(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) (upload_dir / self.filename).write_bytes(self.file_content) (upload_dir / self.filename_extra).write_bytes(self.file_content) self.Storage.delete_all_files() assert not (upload_dir / self.filename).exists() assert not (upload_dir / self.filename_extra).exists() @mock_aws class TestS3StorageProvider: Storage = provider.S3StorageProvider() Storage.bucket_name = "my-bucket" s3_client = boto3.resource("s3", region_name="us-east-1") file_content = b"test content" filename = "test.txt" filename_extra = "test_exyta.txt" file_bytesio_empty = io.BytesIO() def test_upload_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) # S3 checks with pytest.raises(Exception): self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) contents, s3_file_path = self.Storage.upload_file( io.BytesIO(self.file_content), self.filename ) object = self.s3_client.Object(self.Storage.bucket_name, self.filename) assert self.file_content == object.get()["Body"].read() # local checks assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content assert contents == self.file_content assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename with pytest.raises(ValueError): self.Storage.upload_file(self.file_bytesio_empty, self.filename) def test_get_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) contents, s3_file_path = self.Storage.upload_file( io.BytesIO(self.file_content), self.filename ) file_path = self.Storage.get_file(s3_file_path) assert file_path == str(upload_dir / self.filename) assert (upload_dir / self.filename).exists() def test_delete_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) contents, s3_file_path = self.Storage.upload_file( io.BytesIO(self.file_content), self.filename ) assert (upload_dir / self.filename).exists() self.Storage.delete_file(s3_file_path) assert not (upload_dir / self.filename).exists() with pytest.raises(ClientError) as exc: self.s3_client.Object(self.Storage.bucket_name, self.filename).load() error = exc.value.response["Error"] assert error["Code"] == "404" assert error["Message"] == "Not Found" def test_delete_all_files(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) # create 2 files self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) object = self.s3_client.Object(self.Storage.bucket_name, self.filename) assert self.file_content == object.get()["Body"].read() assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra) object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra) assert self.file_content == object.get()["Body"].read() assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content self.Storage.delete_all_files() assert not (upload_dir / self.filename).exists() with pytest.raises(ClientError) as exc: self.s3_client.Object(self.Storage.bucket_name, self.filename).load() error = exc.value.response["Error"] assert error["Code"] == "404" assert error["Message"] == "Not Found" assert not (upload_dir / self.filename_extra).exists() with pytest.raises(ClientError) as exc: self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load() error = exc.value.response["Error"] assert error["Code"] == "404" assert error["Message"] == "Not Found" self.Storage.delete_all_files() assert not (upload_dir / self.filename).exists() assert not (upload_dir / self.filename_extra).exists()