123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270 |
- import io
- import os
- import boto3
- import pytest
- from botocore.exceptions import ClientError
- from moto import mock_aws
- from open_webui.storage import provider
- from gcp_storage_emulator.server import create_server
- from google.cloud import storage
- def mock_upload_dir(monkeypatch, tmp_path):
- """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
- directory = tmp_path / "uploads"
- directory.mkdir()
- monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
- return directory
- def test_imports():
- provider.StorageProvider
- provider.LocalStorageProvider
- provider.S3StorageProvider
- provider.GCSStorageProvider
- provider.Storage
- def test_get_storage_provider():
- Storage = provider.get_storage_provider("local")
- assert isinstance(Storage, provider.LocalStorageProvider)
- Storage = provider.get_storage_provider("s3")
- assert isinstance(Storage, provider.S3StorageProvider)
- Storage = provider.get_storage_provider("gcs")
- assert isinstance(Storage, provider.GCSStorageProvider)
- with pytest.raises(RuntimeError):
- provider.get_storage_provider("invalid")
- def test_class_instantiation():
- with pytest.raises(TypeError):
- provider.StorageProvider()
- with pytest.raises(TypeError):
- class Test(provider.StorageProvider):
- pass
- Test()
- provider.LocalStorageProvider()
- provider.S3StorageProvider()
- provider.GCSStorageProvider()
- class TestLocalStorageProvider:
- Storage = provider.LocalStorageProvider()
- file_content = b"test content"
- file_bytesio = io.BytesIO(file_content)
- filename = "test.txt"
- filename_extra = "test_exyta.txt"
- file_bytesio_empty = io.BytesIO()
- def test_upload_file(self, monkeypatch, tmp_path):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
- assert (upload_dir / self.filename).exists()
- assert (upload_dir / self.filename).read_bytes() == self.file_content
- assert contents == self.file_content
- assert file_path == str(upload_dir / self.filename)
- with pytest.raises(ValueError):
- self.Storage.upload_file(self.file_bytesio_empty, self.filename)
- def test_get_file(self, monkeypatch, tmp_path):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- file_path = str(upload_dir / self.filename)
- file_path_return = self.Storage.get_file(file_path)
- assert file_path == file_path_return
- def test_delete_file(self, monkeypatch, tmp_path):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- (upload_dir / self.filename).write_bytes(self.file_content)
- assert (upload_dir / self.filename).exists()
- file_path = str(upload_dir / self.filename)
- self.Storage.delete_file(file_path)
- assert not (upload_dir / self.filename).exists()
- def test_delete_all_files(self, monkeypatch, tmp_path):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- (upload_dir / self.filename).write_bytes(self.file_content)
- (upload_dir / self.filename_extra).write_bytes(self.file_content)
- self.Storage.delete_all_files()
- assert not (upload_dir / self.filename).exists()
- assert not (upload_dir / self.filename_extra).exists()
- @mock_aws
- class TestS3StorageProvider:
- Storage = provider.S3StorageProvider()
- Storage.bucket_name = "my-bucket"
- s3_client = boto3.resource("s3", region_name="us-east-1")
- file_content = b"test content"
- filename = "test.txt"
- filename_extra = "test_exyta.txt"
- file_bytesio_empty = io.BytesIO()
- def test_upload_file(self, monkeypatch, tmp_path):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- # S3 checks
- with pytest.raises(Exception):
- self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
- self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
- contents, s3_file_path = self.Storage.upload_file(
- io.BytesIO(self.file_content), self.filename
- )
- object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
- assert self.file_content == object.get()["Body"].read()
- # local checks
- assert (upload_dir / self.filename).exists()
- assert (upload_dir / self.filename).read_bytes() == self.file_content
- assert contents == self.file_content
- assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
- with pytest.raises(ValueError):
- self.Storage.upload_file(self.file_bytesio_empty, self.filename)
- def test_get_file(self, monkeypatch, tmp_path):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
- contents, s3_file_path = self.Storage.upload_file(
- io.BytesIO(self.file_content), self.filename
- )
- file_path = self.Storage.get_file(s3_file_path)
- assert file_path == str(upload_dir / self.filename)
- assert (upload_dir / self.filename).exists()
- def test_delete_file(self, monkeypatch, tmp_path):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
- contents, s3_file_path = self.Storage.upload_file(
- io.BytesIO(self.file_content), self.filename
- )
- assert (upload_dir / self.filename).exists()
- self.Storage.delete_file(s3_file_path)
- assert not (upload_dir / self.filename).exists()
- with pytest.raises(ClientError) as exc:
- self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
- error = exc.value.response["Error"]
- assert error["Code"] == "404"
- assert error["Message"] == "Not Found"
- def test_delete_all_files(self, monkeypatch, tmp_path):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- # create 2 files
- self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
- self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
- object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
- assert self.file_content == object.get()["Body"].read()
- assert (upload_dir / self.filename).exists()
- assert (upload_dir / self.filename).read_bytes() == self.file_content
- self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
- object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
- assert self.file_content == object.get()["Body"].read()
- assert (upload_dir / self.filename).exists()
- assert (upload_dir / self.filename).read_bytes() == self.file_content
- self.Storage.delete_all_files()
- assert not (upload_dir / self.filename).exists()
- with pytest.raises(ClientError) as exc:
- self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
- error = exc.value.response["Error"]
- assert error["Code"] == "404"
- assert error["Message"] == "Not Found"
- assert not (upload_dir / self.filename_extra).exists()
- with pytest.raises(ClientError) as exc:
- self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
- error = exc.value.response["Error"]
- assert error["Code"] == "404"
- assert error["Message"] == "Not Found"
- self.Storage.delete_all_files()
- assert not (upload_dir / self.filename).exists()
- assert not (upload_dir / self.filename_extra).exists()
- class TestGCSStorageProvider:
- Storage = provider.GCSStorageProvider()
- Storage.bucket_name = "my-bucket"
- file_content = b"test content"
- filename = "test.txt"
- filename_extra = "test_exyta.txt"
- file_bytesio_empty = io.BytesIO()
- @pytest.fixture
- def setup(self):
- host, port = "localhost", 9023
- server = create_server(host, port, in_memory=True)
- server.start()
- os.environ["STORAGE_EMULATOR_HOST"] = f"http://{host}:{port}"
- gcs_client = storage.Client()
- bucket = gcs_client.bucket(self.Storage.bucket_name)
- bucket.create()
- yield gcs_client, bucket
- bucket.delete(force=True)
- server.stop()
-
- def test_upload_file(self, monkeypatch, tmp_path, setup):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- # test error if bucket does not exist
- with pytest.raises(Exception):
- self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
- # creates bucket and test upload_file method, downloads the file and confirms contents
- self.Storage.gcs_client, self.Storage.bucket = setup
- contents, gcs_file_path = self.Storage.upload_file(
- io.BytesIO(self.file_content), self.filename
- )
- object = self.Storage.bucket.get_blob(self.filename)
- assert self.file_content == object.download_as_bytes()
- # local checks
- assert (upload_dir / self.filename).exists()
- assert (upload_dir / self.filename).read_bytes() == self.file_content
- assert contents == self.file_content
- assert gcs_file_path == "gs://" + self.Storage.bucket_name + "/" + self.filename
- # test error if file is empty
- with pytest.raises(ValueError):
- self.Storage.upload_file(self.file_bytesio_empty, self.filename)
- def test_get_file(self, monkeypatch, tmp_path, setup):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- self.Storage.gcs_client, self.Storage.bucket = setup
- contents, gcs_file_path = self.Storage.upload_file(
- io.BytesIO(self.file_content), self.filename
- )
- file_path = self.Storage.get_file(gcs_file_path)
- assert file_path == str(upload_dir / self.filename)
- assert (upload_dir / self.filename).exists()
- def test_delete_file(self, monkeypatch, tmp_path, setup):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- self.Storage.gcs_client, self.Storage.bucket = setup
- contents, gcs_file_path = self.Storage.upload_file(
- io.BytesIO(self.file_content), self.filename
- )
- # ensure that local directory has the uploaded file as well
- assert (upload_dir / self.filename).exists()
- assert self.Storage.bucket.get_blob(self.filename).name == self.filename
- self.Storage.delete_file(gcs_file_path)
- # check that deleting file from gcs will delete the local file as well
- assert not (upload_dir / self.filename).exists()
- assert self.Storage.bucket.get_blob(self.filename) == None
- def test_delete_all_files(self, monkeypatch, tmp_path, setup):
- upload_dir = mock_upload_dir(monkeypatch, tmp_path)
- self.Storage.gcs_client, self.Storage.bucket = setup
- # create 2 files
- self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
- object = self.Storage.bucket.get_blob(self.filename)
- assert (upload_dir / self.filename).exists()
- assert (upload_dir / self.filename).read_bytes() == self.file_content
- assert self.Storage.bucket.get_blob(self.filename).name == self.filename
- assert self.file_content == object.download_as_bytes()
- self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
- object = self.Storage.bucket.get_blob(self.filename_extra)
- assert (upload_dir / self.filename_extra).exists()
- assert (upload_dir / self.filename_extra).read_bytes() == self.file_content
- assert self.Storage.bucket.get_blob(self.filename_extra).name == self.filename_extra
- assert self.file_content == object.download_as_bytes()
- self.Storage.delete_all_files()
- assert not (upload_dir / self.filename).exists()
- assert not (upload_dir / self.filename_extra).exists()
- assert self.Storage.bucket.get_blob(self.filename) == None
- assert self.Storage.bucket.get_blob(self.filename_extra) == None
|