test_provider.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. import io
  2. import os
  3. import boto3
  4. import pytest
  5. from botocore.exceptions import ClientError
  6. from moto import mock_aws
  7. from open_webui.storage import provider
  8. from gcp_storage_emulator.server import create_server
  9. from google.cloud import storage
  10. from azure.storage.blob import BlobServiceClient, ContainerClient, BlobClient
  11. from unittest.mock import MagicMock
  12. def mock_upload_dir(monkeypatch, tmp_path):
  13. """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
  14. directory = tmp_path / "uploads"
  15. directory.mkdir()
  16. monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
  17. return directory
  18. def test_imports():
  19. provider.StorageProvider
  20. provider.LocalStorageProvider
  21. provider.S3StorageProvider
  22. provider.GCSStorageProvider
  23. provider.AzureStorageProvider
  24. provider.Storage
  25. def test_get_storage_provider():
  26. Storage = provider.get_storage_provider("local")
  27. assert isinstance(Storage, provider.LocalStorageProvider)
  28. Storage = provider.get_storage_provider("s3")
  29. assert isinstance(Storage, provider.S3StorageProvider)
  30. Storage = provider.get_storage_provider("gcs")
  31. assert isinstance(Storage, provider.GCSStorageProvider)
  32. Storage = provider.get_storage_provider("azure")
  33. assert isinstance(Storage, provider.AzureStorageProvider)
  34. with pytest.raises(RuntimeError):
  35. provider.get_storage_provider("invalid")
  36. def test_class_instantiation():
  37. with pytest.raises(TypeError):
  38. provider.StorageProvider()
  39. with pytest.raises(TypeError):
  40. class Test(provider.StorageProvider):
  41. pass
  42. Test()
  43. provider.LocalStorageProvider()
  44. provider.S3StorageProvider()
  45. provider.GCSStorageProvider()
  46. provider.AzureStorageProvider()
  47. class TestLocalStorageProvider:
  48. Storage = provider.LocalStorageProvider()
  49. file_content = b"test content"
  50. file_bytesio = io.BytesIO(file_content)
  51. filename = "test.txt"
  52. filename_extra = "test_exyta.txt"
  53. file_bytesio_empty = io.BytesIO()
  54. def test_upload_file(self, monkeypatch, tmp_path):
  55. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  56. contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
  57. assert (upload_dir / self.filename).exists()
  58. assert (upload_dir / self.filename).read_bytes() == self.file_content
  59. assert contents == self.file_content
  60. assert file_path == str(upload_dir / self.filename)
  61. with pytest.raises(ValueError):
  62. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  63. def test_get_file(self, monkeypatch, tmp_path):
  64. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  65. file_path = str(upload_dir / self.filename)
  66. file_path_return = self.Storage.get_file(file_path)
  67. assert file_path == file_path_return
  68. def test_delete_file(self, monkeypatch, tmp_path):
  69. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  70. (upload_dir / self.filename).write_bytes(self.file_content)
  71. assert (upload_dir / self.filename).exists()
  72. file_path = str(upload_dir / self.filename)
  73. self.Storage.delete_file(file_path)
  74. assert not (upload_dir / self.filename).exists()
  75. def test_delete_all_files(self, monkeypatch, tmp_path):
  76. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  77. (upload_dir / self.filename).write_bytes(self.file_content)
  78. (upload_dir / self.filename_extra).write_bytes(self.file_content)
  79. self.Storage.delete_all_files()
  80. assert not (upload_dir / self.filename).exists()
  81. assert not (upload_dir / self.filename_extra).exists()
  82. @mock_aws
  83. class TestS3StorageProvider:
  84. def __init__(self):
  85. self.Storage = provider.S3StorageProvider()
  86. self.Storage.bucket_name = "my-bucket"
  87. self.s3_client = boto3.resource("s3", region_name="us-east-1")
  88. self.file_content = b"test content"
  89. self.filename = "test.txt"
  90. self.filename_extra = "test_exyta.txt"
  91. self.file_bytesio_empty = io.BytesIO()
  92. super().__init__()
  93. def test_upload_file(self, monkeypatch, tmp_path):
  94. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  95. # S3 checks
  96. with pytest.raises(Exception):
  97. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  98. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  99. contents, s3_file_path = self.Storage.upload_file(
  100. io.BytesIO(self.file_content), self.filename
  101. )
  102. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  103. assert self.file_content == object.get()["Body"].read()
  104. # local checks
  105. assert (upload_dir / self.filename).exists()
  106. assert (upload_dir / self.filename).read_bytes() == self.file_content
  107. assert contents == self.file_content
  108. assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
  109. with pytest.raises(ValueError):
  110. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  111. def test_get_file(self, monkeypatch, tmp_path):
  112. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  113. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  114. contents, s3_file_path = self.Storage.upload_file(
  115. io.BytesIO(self.file_content), self.filename
  116. )
  117. file_path = self.Storage.get_file(s3_file_path)
  118. assert file_path == str(upload_dir / self.filename)
  119. assert (upload_dir / self.filename).exists()
  120. def test_delete_file(self, monkeypatch, tmp_path):
  121. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  122. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  123. contents, s3_file_path = self.Storage.upload_file(
  124. io.BytesIO(self.file_content), self.filename
  125. )
  126. assert (upload_dir / self.filename).exists()
  127. self.Storage.delete_file(s3_file_path)
  128. assert not (upload_dir / self.filename).exists()
  129. with pytest.raises(ClientError) as exc:
  130. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  131. error = exc.value.response["Error"]
  132. assert error["Code"] == "404"
  133. assert error["Message"] == "Not Found"
  134. def test_delete_all_files(self, monkeypatch, tmp_path):
  135. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  136. # create 2 files
  137. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  138. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  139. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  140. assert self.file_content == object.get()["Body"].read()
  141. assert (upload_dir / self.filename).exists()
  142. assert (upload_dir / self.filename).read_bytes() == self.file_content
  143. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  144. object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
  145. assert self.file_content == object.get()["Body"].read()
  146. assert (upload_dir / self.filename).exists()
  147. assert (upload_dir / self.filename).read_bytes() == self.file_content
  148. self.Storage.delete_all_files()
  149. assert not (upload_dir / self.filename).exists()
  150. with pytest.raises(ClientError) as exc:
  151. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  152. error = exc.value.response["Error"]
  153. assert error["Code"] == "404"
  154. assert error["Message"] == "Not Found"
  155. assert not (upload_dir / self.filename_extra).exists()
  156. with pytest.raises(ClientError) as exc:
  157. self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
  158. error = exc.value.response["Error"]
  159. assert error["Code"] == "404"
  160. assert error["Message"] == "Not Found"
  161. self.Storage.delete_all_files()
  162. assert not (upload_dir / self.filename).exists()
  163. assert not (upload_dir / self.filename_extra).exists()
  164. class TestGCSStorageProvider:
  165. Storage = provider.GCSStorageProvider()
  166. Storage.bucket_name = "my-bucket"
  167. file_content = b"test content"
  168. filename = "test.txt"
  169. filename_extra = "test_exyta.txt"
  170. file_bytesio_empty = io.BytesIO()
  171. @pytest.fixture(scope="class")
  172. def setup(self):
  173. host, port = "localhost", 9023
  174. server = create_server(host, port, in_memory=True)
  175. server.start()
  176. os.environ["STORAGE_EMULATOR_HOST"] = f"http://{host}:{port}"
  177. gcs_client = storage.Client()
  178. bucket = gcs_client.bucket(self.Storage.bucket_name)
  179. bucket.create()
  180. self.Storage.gcs_client, self.Storage.bucket = gcs_client, bucket
  181. yield
  182. bucket.delete(force=True)
  183. server.stop()
  184. def test_upload_file(self, monkeypatch, tmp_path, setup):
  185. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  186. # catch error if bucket does not exist
  187. with pytest.raises(Exception):
  188. self.Storage.bucket = monkeypatch(self.Storage, "bucket", None)
  189. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  190. contents, gcs_file_path = self.Storage.upload_file(
  191. io.BytesIO(self.file_content), self.filename
  192. )
  193. object = self.Storage.bucket.get_blob(self.filename)
  194. assert self.file_content == object.download_as_bytes()
  195. # local checks
  196. assert (upload_dir / self.filename).exists()
  197. assert (upload_dir / self.filename).read_bytes() == self.file_content
  198. assert contents == self.file_content
  199. assert gcs_file_path == "gs://" + self.Storage.bucket_name + "/" + self.filename
  200. # test error if file is empty
  201. with pytest.raises(ValueError):
  202. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  203. def test_get_file(self, monkeypatch, tmp_path, setup):
  204. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  205. contents, gcs_file_path = self.Storage.upload_file(
  206. io.BytesIO(self.file_content), self.filename
  207. )
  208. file_path = self.Storage.get_file(gcs_file_path)
  209. assert file_path == str(upload_dir / self.filename)
  210. assert (upload_dir / self.filename).exists()
  211. def test_delete_file(self, monkeypatch, tmp_path, setup):
  212. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  213. contents, gcs_file_path = self.Storage.upload_file(
  214. io.BytesIO(self.file_content), self.filename
  215. )
  216. # ensure that local directory has the uploaded file as well
  217. assert (upload_dir / self.filename).exists()
  218. assert self.Storage.bucket.get_blob(self.filename).name == self.filename
  219. self.Storage.delete_file(gcs_file_path)
  220. # check that deleting file from gcs will delete the local file as well
  221. assert not (upload_dir / self.filename).exists()
  222. assert self.Storage.bucket.get_blob(self.filename) == None
  223. def test_delete_all_files(self, monkeypatch, tmp_path, setup):
  224. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  225. # create 2 files
  226. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  227. object = self.Storage.bucket.get_blob(self.filename)
  228. assert (upload_dir / self.filename).exists()
  229. assert (upload_dir / self.filename).read_bytes() == self.file_content
  230. assert self.Storage.bucket.get_blob(self.filename).name == self.filename
  231. assert self.file_content == object.download_as_bytes()
  232. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  233. object = self.Storage.bucket.get_blob(self.filename_extra)
  234. assert (upload_dir / self.filename_extra).exists()
  235. assert (upload_dir / self.filename_extra).read_bytes() == self.file_content
  236. assert (
  237. self.Storage.bucket.get_blob(self.filename_extra).name
  238. == self.filename_extra
  239. )
  240. assert self.file_content == object.download_as_bytes()
  241. self.Storage.delete_all_files()
  242. assert not (upload_dir / self.filename).exists()
  243. assert not (upload_dir / self.filename_extra).exists()
  244. assert self.Storage.bucket.get_blob(self.filename) == None
  245. assert self.Storage.bucket.get_blob(self.filename_extra) == None
  246. class TestAzureStorageProvider:
  247. def __init__(self):
  248. super().__init__()
  249. @pytest.fixture(scope="class")
  250. def setup_storage(self, monkeypatch):
  251. # Create mock Blob Service Client and related clients
  252. mock_blob_service_client = MagicMock()
  253. mock_container_client = MagicMock()
  254. mock_blob_client = MagicMock()
  255. # Set up return values for the mock
  256. mock_blob_service_client.get_container_client.return_value = (
  257. mock_container_client
  258. )
  259. mock_container_client.get_blob_client.return_value = mock_blob_client
  260. # Monkeypatch the Azure classes to return our mocks
  261. monkeypatch.setattr(
  262. azure.storage.blob,
  263. "BlobServiceClient",
  264. lambda *args, **kwargs: mock_blob_service_client,
  265. )
  266. monkeypatch.setattr(
  267. azure.storage.blob,
  268. "ContainerClient",
  269. lambda *args, **kwargs: mock_container_client,
  270. )
  271. monkeypatch.setattr(
  272. azure.storage.blob, "BlobClient", lambda *args, **kwargs: mock_blob_client
  273. )
  274. self.Storage = provider.AzureStorageProvider()
  275. self.Storage.endpoint = "https://myaccount.blob.core.windows.net"
  276. self.Storage.container_name = "my-container"
  277. self.file_content = b"test content"
  278. self.filename = "test.txt"
  279. self.filename_extra = "test_extra.txt"
  280. self.file_bytesio_empty = io.BytesIO()
  281. # Apply mocks to the Storage instance
  282. self.Storage.blob_service_client = mock_blob_service_client
  283. self.Storage.container_client = mock_container_client
  284. def test_upload_file(self, monkeypatch, tmp_path):
  285. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  286. # Simulate an error when container does not exist
  287. self.Storage.container_client.get_blob_client.side_effect = Exception(
  288. "Container does not exist"
  289. )
  290. with pytest.raises(Exception):
  291. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  292. # Reset side effect and create container
  293. self.Storage.container_client.get_blob_client.side_effect = None
  294. self.Storage.create_container()
  295. contents, azure_file_path = self.Storage.upload_file(
  296. io.BytesIO(self.file_content), self.filename
  297. )
  298. # Assertions
  299. self.Storage.container_client.get_blob_client.assert_called_with(self.filename)
  300. self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with(
  301. self.file_content, overwrite=True
  302. )
  303. assert contents == self.file_content
  304. assert (
  305. azure_file_path
  306. == f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
  307. )
  308. assert (upload_dir / self.filename).exists()
  309. assert (upload_dir / self.filename).read_bytes() == self.file_content
  310. with pytest.raises(ValueError):
  311. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  312. def test_get_file(self, monkeypatch, tmp_path):
  313. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  314. self.Storage.create_container()
  315. # Mock upload behavior
  316. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  317. # Mock blob download behavior
  318. self.Storage.container_client.get_blob_client().download_blob().readall.return_value = (
  319. self.file_content
  320. )
  321. file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
  322. file_path = self.Storage.get_file(file_url)
  323. assert file_path == str(upload_dir / self.filename)
  324. assert (upload_dir / self.filename).exists()
  325. assert (upload_dir / self.filename).read_bytes() == self.file_content
  326. def test_delete_file(self, monkeypatch, tmp_path):
  327. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  328. self.Storage.create_container()
  329. # Mock file upload
  330. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  331. # Mock deletion
  332. self.Storage.container_client.get_blob_client().delete_blob.return_value = None
  333. file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
  334. self.Storage.delete_file(file_url)
  335. self.Storage.container_client.get_blob_client().delete_blob.assert_called_once()
  336. assert not (upload_dir / self.filename).exists()
  337. def test_delete_all_files(self, monkeypatch, tmp_path):
  338. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  339. self.Storage.create_container()
  340. # Mock file uploads
  341. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  342. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  343. # Mock listing and deletion behavior
  344. self.Storage.container_client.list_blobs.return_value = [
  345. {"name": self.filename},
  346. {"name": self.filename_extra},
  347. ]
  348. self.Storage.container_client.get_blob_client().delete_blob.return_value = None
  349. self.Storage.delete_all_files()
  350. self.Storage.container_client.list_blobs.assert_called_once()
  351. self.Storage.container_client.get_blob_client().delete_blob.assert_any_call()
  352. assert not (upload_dir / self.filename).exists()
  353. assert not (upload_dir / self.filename_extra).exists()
  354. def test_get_file_not_found(self, monkeypatch):
  355. self.Storage.create_container()
  356. file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
  357. # Mock behavior to raise an error for missing blobs
  358. self.Storage.container_client.get_blob_client().download_blob.side_effect = (
  359. Exception("Blob not found")
  360. )
  361. with pytest.raises(Exception, match="Blob not found"):
  362. self.Storage.get_file(file_url)