test_provider.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import io
  2. import boto3
  3. import pytest
  4. from botocore.exceptions import ClientError
  5. from moto import mock_aws
  6. from open_webui.storage import provider
  7. def mock_upload_dir(monkeypatch, tmp_path):
  8. """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
  9. directory = tmp_path / "uploads"
  10. directory.mkdir()
  11. monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
  12. return directory
  13. def test_imports():
  14. provider.StorageProvider
  15. provider.LocalStorageProvider
  16. provider.S3StorageProvider
  17. provider.Storage
  18. def test_get_storage_provider():
  19. Storage = provider.get_storage_provider("local")
  20. assert isinstance(Storage, provider.LocalStorageProvider)
  21. Storage = provider.get_storage_provider("s3")
  22. assert isinstance(Storage, provider.S3StorageProvider)
  23. with pytest.raises(RuntimeError):
  24. provider.get_storage_provider("invalid")
  25. def test_class_instantiation():
  26. with pytest.raises(TypeError):
  27. provider.StorageProvider()
  28. with pytest.raises(TypeError):
  29. class Test(provider.StorageProvider):
  30. pass
  31. Test()
  32. provider.LocalStorageProvider()
  33. provider.S3StorageProvider()
  34. class TestLocalStorageProvider:
  35. Storage = provider.LocalStorageProvider()
  36. file_content = b"test content"
  37. file_bytesio = io.BytesIO(file_content)
  38. filename = "test.txt"
  39. filename_extra = "test_exyta.txt"
  40. file_bytesio_empty = io.BytesIO()
  41. def test_upload_file(self, monkeypatch, tmp_path):
  42. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  43. contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
  44. assert (upload_dir / self.filename).exists()
  45. assert (upload_dir / self.filename).read_bytes() == self.file_content
  46. assert contents == self.file_content
  47. assert file_path == str(upload_dir / self.filename)
  48. with pytest.raises(ValueError):
  49. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  50. def test_get_file(self, monkeypatch, tmp_path):
  51. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  52. file_path = str(upload_dir / self.filename)
  53. file_path_return = self.Storage.get_file(file_path)
  54. assert file_path == file_path_return
  55. def test_delete_file(self, monkeypatch, tmp_path):
  56. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  57. (upload_dir / self.filename).write_bytes(self.file_content)
  58. assert (upload_dir / self.filename).exists()
  59. file_path = str(upload_dir / self.filename)
  60. self.Storage.delete_file(file_path)
  61. assert not (upload_dir / self.filename).exists()
  62. def test_delete_all_files(self, monkeypatch, tmp_path):
  63. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  64. (upload_dir / self.filename).write_bytes(self.file_content)
  65. (upload_dir / self.filename_extra).write_bytes(self.file_content)
  66. self.Storage.delete_all_files()
  67. assert not (upload_dir / self.filename).exists()
  68. assert not (upload_dir / self.filename_extra).exists()
  69. @mock_aws
  70. class TestS3StorageProvider:
  71. Storage = provider.S3StorageProvider()
  72. Storage.bucket_name = "my-bucket"
  73. s3_client = boto3.resource("s3", region_name="us-east-1")
  74. file_content = b"test content"
  75. filename = "test.txt"
  76. filename_extra = "test_exyta.txt"
  77. file_bytesio_empty = io.BytesIO()
  78. def test_upload_file(self, monkeypatch, tmp_path):
  79. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  80. # S3 checks
  81. with pytest.raises(Exception):
  82. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  83. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  84. contents, s3_file_path = self.Storage.upload_file(
  85. io.BytesIO(self.file_content), self.filename
  86. )
  87. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  88. assert self.file_content == object.get()["Body"].read()
  89. # local checks
  90. assert (upload_dir / self.filename).exists()
  91. assert (upload_dir / self.filename).read_bytes() == self.file_content
  92. assert contents == self.file_content
  93. assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
  94. with pytest.raises(ValueError):
  95. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  96. def test_get_file(self, monkeypatch, tmp_path):
  97. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  98. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  99. contents, s3_file_path = self.Storage.upload_file(
  100. io.BytesIO(self.file_content), self.filename
  101. )
  102. file_path = self.Storage.get_file(s3_file_path)
  103. assert file_path == str(upload_dir / self.filename)
  104. assert (upload_dir / self.filename).exists()
  105. def test_delete_file(self, monkeypatch, tmp_path):
  106. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  107. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  108. contents, s3_file_path = self.Storage.upload_file(
  109. io.BytesIO(self.file_content), self.filename
  110. )
  111. assert (upload_dir / self.filename).exists()
  112. self.Storage.delete_file(s3_file_path)
  113. assert not (upload_dir / self.filename).exists()
  114. with pytest.raises(ClientError) as exc:
  115. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  116. error = exc.value.response["Error"]
  117. assert error["Code"] == "404"
  118. assert error["Message"] == "Not Found"
  119. def test_delete_all_files(self, monkeypatch, tmp_path):
  120. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  121. # create 2 files
  122. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  123. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  124. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  125. assert self.file_content == object.get()["Body"].read()
  126. assert (upload_dir / self.filename).exists()
  127. assert (upload_dir / self.filename).read_bytes() == self.file_content
  128. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  129. object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
  130. assert self.file_content == object.get()["Body"].read()
  131. assert (upload_dir / self.filename).exists()
  132. assert (upload_dir / self.filename).read_bytes() == self.file_content
  133. self.Storage.delete_all_files()
  134. assert not (upload_dir / self.filename).exists()
  135. with pytest.raises(ClientError) as exc:
  136. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  137. error = exc.value.response["Error"]
  138. assert error["Code"] == "404"
  139. assert error["Message"] == "Not Found"
  140. assert not (upload_dir / self.filename_extra).exists()
  141. with pytest.raises(ClientError) as exc:
  142. self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
  143. error = exc.value.response["Error"]
  144. assert error["Code"] == "404"
  145. assert error["Message"] == "Not Found"
  146. self.Storage.delete_all_files()
  147. assert not (upload_dir / self.filename).exists()
  148. assert not (upload_dir / self.filename_extra).exists()