# test_provider.py: tests for the storage providers in open_webui.storage.provider
  1. import io
  2. import os
  3. import boto3
  4. import pytest
  5. from botocore.exceptions import ClientError
  6. from moto import mock_aws
  7. from open_webui.storage import provider
  8. from gcp_storage_emulator.server import create_server
  9. from google.cloud import storage
  10. def mock_upload_dir(monkeypatch, tmp_path):
  11. """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
  12. directory = tmp_path / "uploads"
  13. directory.mkdir()
  14. monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
  15. return directory
  16. def test_imports():
  17. provider.StorageProvider
  18. provider.LocalStorageProvider
  19. provider.S3StorageProvider
  20. provider.GCSStorageProvider
  21. provider.AzureStorageProvider
  22. provider.Storage
  23. def test_get_storage_provider():
  24. Storage = provider.get_storage_provider("local")
  25. assert isinstance(Storage, provider.LocalStorageProvider)
  26. Storage = provider.get_storage_provider("s3")
  27. assert isinstance(Storage, provider.S3StorageProvider)
  28. Storage = provider.get_storage_provider("gcs")
  29. assert isinstance(Storage, provider.GCSStorageProvider)
  30. Storage = provider.get_storage_provider("azure")
  31. assert isinstance(Storage, provider.AzureStorageProvider)
  32. with pytest.raises(RuntimeError):
  33. provider.get_storage_provider("invalid")
  34. def test_class_instantiation():
  35. with pytest.raises(TypeError):
  36. provider.StorageProvider()
  37. with pytest.raises(TypeError):
  38. class Test(provider.StorageProvider):
  39. pass
  40. Test()
  41. provider.LocalStorageProvider()
  42. provider.S3StorageProvider()
  43. provider.GCSStorageProvider()
  44. provider.AzureStorageProvider()
  45. class TestLocalStorageProvider:
  46. Storage = provider.LocalStorageProvider()
  47. file_content = b"test content"
  48. file_bytesio = io.BytesIO(file_content)
  49. filename = "test.txt"
  50. filename_extra = "test_exyta.txt"
  51. file_bytesio_empty = io.BytesIO()
  52. def test_upload_file(self, monkeypatch, tmp_path):
  53. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  54. contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
  55. assert (upload_dir / self.filename).exists()
  56. assert (upload_dir / self.filename).read_bytes() == self.file_content
  57. assert contents == self.file_content
  58. assert file_path == str(upload_dir / self.filename)
  59. with pytest.raises(ValueError):
  60. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  61. def test_get_file(self, monkeypatch, tmp_path):
  62. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  63. file_path = str(upload_dir / self.filename)
  64. file_path_return = self.Storage.get_file(file_path)
  65. assert file_path == file_path_return
  66. def test_delete_file(self, monkeypatch, tmp_path):
  67. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  68. (upload_dir / self.filename).write_bytes(self.file_content)
  69. assert (upload_dir / self.filename).exists()
  70. file_path = str(upload_dir / self.filename)
  71. self.Storage.delete_file(file_path)
  72. assert not (upload_dir / self.filename).exists()
  73. def test_delete_all_files(self, monkeypatch, tmp_path):
  74. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  75. (upload_dir / self.filename).write_bytes(self.file_content)
  76. (upload_dir / self.filename_extra).write_bytes(self.file_content)
  77. self.Storage.delete_all_files()
  78. assert not (upload_dir / self.filename).exists()
  79. assert not (upload_dir / self.filename_extra).exists()
  80. @mock_aws
  81. class TestS3StorageProvider:
  82. def __init__(self):
  83. self.Storage = provider.S3StorageProvider()
  84. self.Storage.bucket_name = "my-bucket"
  85. self.s3_client = boto3.resource("s3", region_name="us-east-1")
  86. self.file_content = b"test content"
  87. self.filename = "test.txt"
  88. self.filename_extra = "test_exyta.txt"
  89. self.file_bytesio_empty = io.BytesIO()
  90. super().__init__()
  91. def test_upload_file(self, monkeypatch, tmp_path):
  92. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  93. # S3 checks
  94. with pytest.raises(Exception):
  95. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  96. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  97. contents, s3_file_path = self.Storage.upload_file(
  98. io.BytesIO(self.file_content), self.filename
  99. )
  100. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  101. assert self.file_content == object.get()["Body"].read()
  102. # local checks
  103. assert (upload_dir / self.filename).exists()
  104. assert (upload_dir / self.filename).read_bytes() == self.file_content
  105. assert contents == self.file_content
  106. assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
  107. with pytest.raises(ValueError):
  108. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  109. def test_get_file(self, monkeypatch, tmp_path):
  110. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  111. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  112. contents, s3_file_path = self.Storage.upload_file(
  113. io.BytesIO(self.file_content), self.filename
  114. )
  115. file_path = self.Storage.get_file(s3_file_path)
  116. assert file_path == str(upload_dir / self.filename)
  117. assert (upload_dir / self.filename).exists()
  118. def test_delete_file(self, monkeypatch, tmp_path):
  119. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  120. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  121. contents, s3_file_path = self.Storage.upload_file(
  122. io.BytesIO(self.file_content), self.filename
  123. )
  124. assert (upload_dir / self.filename).exists()
  125. self.Storage.delete_file(s3_file_path)
  126. assert not (upload_dir / self.filename).exists()
  127. with pytest.raises(ClientError) as exc:
  128. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  129. error = exc.value.response["Error"]
  130. assert error["Code"] == "404"
  131. assert error["Message"] == "Not Found"
  132. def test_delete_all_files(self, monkeypatch, tmp_path):
  133. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  134. # create 2 files
  135. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  136. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  137. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  138. assert self.file_content == object.get()["Body"].read()
  139. assert (upload_dir / self.filename).exists()
  140. assert (upload_dir / self.filename).read_bytes() == self.file_content
  141. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  142. object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
  143. assert self.file_content == object.get()["Body"].read()
  144. assert (upload_dir / self.filename).exists()
  145. assert (upload_dir / self.filename).read_bytes() == self.file_content
  146. self.Storage.delete_all_files()
  147. assert not (upload_dir / self.filename).exists()
  148. with pytest.raises(ClientError) as exc:
  149. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  150. error = exc.value.response["Error"]
  151. assert error["Code"] == "404"
  152. assert error["Message"] == "Not Found"
  153. assert not (upload_dir / self.filename_extra).exists()
  154. with pytest.raises(ClientError) as exc:
  155. self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
  156. error = exc.value.response["Error"]
  157. assert error["Code"] == "404"
  158. assert error["Message"] == "Not Found"
  159. self.Storage.delete_all_files()
  160. assert not (upload_dir / self.filename).exists()
  161. assert not (upload_dir / self.filename_extra).exists()
  162. class TestGCSStorageProvider:
  163. Storage = provider.GCSStorageProvider()
  164. Storage.bucket_name = "my-bucket"
  165. file_content = b"test content"
  166. filename = "test.txt"
  167. filename_extra = "test_exyta.txt"
  168. file_bytesio_empty = io.BytesIO()
  169. @pytest.fixture(scope="class")
  170. def setup(self):
  171. host, port = "localhost", 9023
  172. server = create_server(host, port, in_memory=True)
  173. server.start()
  174. os.environ["STORAGE_EMULATOR_HOST"] = f"http://{host}:{port}"
  175. gcs_client = storage.Client()
  176. bucket = gcs_client.bucket(self.Storage.bucket_name)
  177. bucket.create()
  178. self.Storage.gcs_client, self.Storage.bucket = gcs_client, bucket
  179. yield
  180. bucket.delete(force=True)
  181. server.stop()
  182. def test_upload_file(self, monkeypatch, tmp_path, setup):
  183. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  184. # catch error if bucket does not exist
  185. with pytest.raises(Exception):
  186. self.Storage.bucket = monkeypatch(self.Storage, "bucket", None)
  187. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  188. contents, gcs_file_path = self.Storage.upload_file(
  189. io.BytesIO(self.file_content), self.filename
  190. )
  191. object = self.Storage.bucket.get_blob(self.filename)
  192. assert self.file_content == object.download_as_bytes()
  193. # local checks
  194. assert (upload_dir / self.filename).exists()
  195. assert (upload_dir / self.filename).read_bytes() == self.file_content
  196. assert contents == self.file_content
  197. assert gcs_file_path == "gs://" + self.Storage.bucket_name + "/" + self.filename
  198. # test error if file is empty
  199. with pytest.raises(ValueError):
  200. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  201. def test_get_file(self, monkeypatch, tmp_path, setup):
  202. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  203. contents, gcs_file_path = self.Storage.upload_file(
  204. io.BytesIO(self.file_content), self.filename
  205. )
  206. file_path = self.Storage.get_file(gcs_file_path)
  207. assert file_path == str(upload_dir / self.filename)
  208. assert (upload_dir / self.filename).exists()
  209. def test_delete_file(self, monkeypatch, tmp_path, setup):
  210. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  211. contents, gcs_file_path = self.Storage.upload_file(
  212. io.BytesIO(self.file_content), self.filename
  213. )
  214. # ensure that local directory has the uploaded file as well
  215. assert (upload_dir / self.filename).exists()
  216. assert self.Storage.bucket.get_blob(self.filename).name == self.filename
  217. self.Storage.delete_file(gcs_file_path)
  218. # check that deleting file from gcs will delete the local file as well
  219. assert not (upload_dir / self.filename).exists()
  220. assert self.Storage.bucket.get_blob(self.filename) == None
  221. def test_delete_all_files(self, monkeypatch, tmp_path, setup):
  222. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  223. # create 2 files
  224. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  225. object = self.Storage.bucket.get_blob(self.filename)
  226. assert (upload_dir / self.filename).exists()
  227. assert (upload_dir / self.filename).read_bytes() == self.file_content
  228. assert self.Storage.bucket.get_blob(self.filename).name == self.filename
  229. assert self.file_content == object.download_as_bytes()
  230. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  231. object = self.Storage.bucket.get_blob(self.filename_extra)
  232. assert (upload_dir / self.filename_extra).exists()
  233. assert (upload_dir / self.filename_extra).read_bytes() == self.file_content
  234. assert (
  235. self.Storage.bucket.get_blob(self.filename_extra).name
  236. == self.filename_extra
  237. )
  238. assert self.file_content == object.download_as_bytes()
  239. self.Storage.delete_all_files()
  240. assert not (upload_dir / self.filename).exists()
  241. assert not (upload_dir / self.filename_extra).exists()
  242. assert self.Storage.bucket.get_blob(self.filename) == None
  243. assert self.Storage.bucket.get_blob(self.filename_extra) == None