# test_provider.py (7.3 KB)
  1. import io
  2. import boto3
  3. import pytest
  4. from botocore.exceptions import ClientError
  5. from moto import mock_aws
  6. from open_webui.storage import provider
  7. def mock_upload_dir(monkeypatch, tmp_path):
  8. """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
  9. directory = tmp_path / "uploads"
  10. directory.mkdir()
  11. monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
  12. return directory
  13. def test_imports():
  14. provider.StorageProvider
  15. provider.LocalStorageProvider
  16. provider.S3StorageProvider
  17. provider.Storage
  18. def test_get_storage_provider():
  19. Storage = provider.get_storage_provider("local")
  20. assert isinstance(Storage, provider.LocalStorageProvider)
  21. Storage = provider.get_storage_provider("s3")
  22. assert isinstance(Storage, provider.S3StorageProvider)
  23. with pytest.raises(RuntimeError):
  24. provider.get_storage_provider("invalid")
  25. def test_class_instantiation():
  26. with pytest.raises(TypeError):
  27. provider.StorageProvider()
  28. with pytest.raises(TypeError):
  29. class Test(provider.StorageProvider):
  30. pass
  31. Test()
  32. provider.LocalStorageProvider()
  33. provider.S3StorageProvider()
  34. class TestLocalStorageProvider:
  35. Storage = provider.LocalStorageProvider()
  36. file_content = b"test content"
  37. file_bytesio = io.BytesIO(file_content)
  38. filename = "test.txt"
  39. filename_extra = "test_exyta.txt"
  40. file_bytesio_empty = io.BytesIO()
  41. def test_upload_file(self, monkeypatch, tmp_path):
  42. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  43. contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
  44. assert (upload_dir / self.filename).exists()
  45. assert (upload_dir / self.filename).read_bytes() == self.file_content
  46. assert contents == self.file_content
  47. assert file_path == str(upload_dir / self.filename)
  48. with pytest.raises(ValueError):
  49. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  50. def test_get_file(self, monkeypatch, tmp_path):
  51. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  52. file_path = str(upload_dir / self.filename)
  53. file_path_return = self.Storage.get_file(file_path)
  54. assert file_path == file_path_return
  55. def test_delete_file(self, monkeypatch, tmp_path):
  56. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  57. (upload_dir / self.filename).write_bytes(self.file_content)
  58. assert (upload_dir / self.filename).exists()
  59. file_path = str(upload_dir / self.filename)
  60. self.Storage.delete_file(file_path)
  61. assert not (upload_dir / self.filename).exists()
  62. def test_delete_all_files(self, monkeypatch, tmp_path):
  63. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  64. (upload_dir / self.filename).write_bytes(self.file_content)
  65. (upload_dir / self.filename_extra).write_bytes(self.file_content)
  66. self.Storage.delete_all_files()
  67. assert not (upload_dir / self.filename).exists()
  68. assert not (upload_dir / self.filename_extra).exists()
  69. @mock_aws
  70. class TestS3StorageProvider:
  71. def __init__(self):
  72. self.Storage = provider.S3StorageProvider()
  73. self.Storage.bucket_name = "my-bucket"
  74. self.s3_client = boto3.resource("s3", region_name="us-east-1")
  75. self.file_content = b"test content"
  76. self.filename = "test.txt"
  77. self.filename_extra = "test_exyta.txt"
  78. self.file_bytesio_empty = io.BytesIO()
  79. super().__init__()
  80. def test_upload_file(self, monkeypatch, tmp_path):
  81. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  82. # S3 checks
  83. with pytest.raises(Exception):
  84. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  85. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  86. contents, s3_file_path = self.Storage.upload_file(
  87. io.BytesIO(self.file_content), self.filename
  88. )
  89. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  90. assert self.file_content == object.get()["Body"].read()
  91. # local checks
  92. assert (upload_dir / self.filename).exists()
  93. assert (upload_dir / self.filename).read_bytes() == self.file_content
  94. assert contents == self.file_content
  95. assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
  96. with pytest.raises(ValueError):
  97. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  98. def test_get_file(self, monkeypatch, tmp_path):
  99. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  100. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  101. contents, s3_file_path = self.Storage.upload_file(
  102. io.BytesIO(self.file_content), self.filename
  103. )
  104. file_path = self.Storage.get_file(s3_file_path)
  105. assert file_path == str(upload_dir / self.filename)
  106. assert (upload_dir / self.filename).exists()
  107. def test_delete_file(self, monkeypatch, tmp_path):
  108. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  109. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  110. contents, s3_file_path = self.Storage.upload_file(
  111. io.BytesIO(self.file_content), self.filename
  112. )
  113. assert (upload_dir / self.filename).exists()
  114. self.Storage.delete_file(s3_file_path)
  115. assert not (upload_dir / self.filename).exists()
  116. with pytest.raises(ClientError) as exc:
  117. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  118. error = exc.value.response["Error"]
  119. assert error["Code"] == "404"
  120. assert error["Message"] == "Not Found"
  121. def test_delete_all_files(self, monkeypatch, tmp_path):
  122. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  123. # create 2 files
  124. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  125. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  126. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  127. assert self.file_content == object.get()["Body"].read()
  128. assert (upload_dir / self.filename).exists()
  129. assert (upload_dir / self.filename).read_bytes() == self.file_content
  130. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  131. object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
  132. assert self.file_content == object.get()["Body"].read()
  133. assert (upload_dir / self.filename).exists()
  134. assert (upload_dir / self.filename).read_bytes() == self.file_content
  135. self.Storage.delete_all_files()
  136. assert not (upload_dir / self.filename).exists()
  137. with pytest.raises(ClientError) as exc:
  138. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  139. error = exc.value.response["Error"]
  140. assert error["Code"] == "404"
  141. assert error["Message"] == "Not Found"
  142. assert not (upload_dir / self.filename_extra).exists()
  143. with pytest.raises(ClientError) as exc:
  144. self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
  145. error = exc.value.response["Error"]
  146. assert error["Code"] == "404"
  147. assert error["Message"] == "Not Found"
  148. self.Storage.delete_all_files()
  149. assert not (upload_dir / self.filename).exists()
  150. assert not (upload_dir / self.filename_extra).exists()