test_provider.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. import io
  2. import os
  3. import boto3
  4. import pytest
  5. from botocore.exceptions import ClientError
  6. from moto import mock_aws
  7. from open_webui.storage import provider
  8. from gcp_storage_emulator.server import create_server
  9. from google.cloud import storage
  10. from azure.storage.blob import BlobServiceClient, ContainerClient, BlobClient
  11. from unittest.mock import MagicMock
  12. def mock_upload_dir(monkeypatch, tmp_path):
  13. """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
  14. directory = tmp_path / "uploads"
  15. directory.mkdir()
  16. monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
  17. return directory
  18. def test_imports():
  19. provider.StorageProvider
  20. provider.LocalStorageProvider
  21. provider.S3StorageProvider
  22. provider.GCSStorageProvider
  23. provider.AzureStorageProvider
  24. provider.Storage
  25. def test_get_storage_provider():
  26. Storage = provider.get_storage_provider("local")
  27. assert isinstance(Storage, provider.LocalStorageProvider)
  28. Storage = provider.get_storage_provider("s3")
  29. assert isinstance(Storage, provider.S3StorageProvider)
  30. Storage = provider.get_storage_provider("gcs")
  31. assert isinstance(Storage, provider.GCSStorageProvider)
  32. Storage = provider.get_storage_provider("azure")
  33. assert isinstance(Storage, provider.AzureStorageProvider)
  34. with pytest.raises(RuntimeError):
  35. provider.get_storage_provider("invalid")
  36. def test_class_instantiation():
  37. with pytest.raises(TypeError):
  38. provider.StorageProvider()
  39. with pytest.raises(TypeError):
  40. class Test(provider.StorageProvider):
  41. pass
  42. Test()
  43. provider.LocalStorageProvider()
  44. provider.S3StorageProvider()
  45. provider.GCSStorageProvider()
  46. provider.AzureStorageProvider()
  47. class TestLocalStorageProvider:
  48. Storage = provider.LocalStorageProvider()
  49. file_content = b"test content"
  50. file_bytesio = io.BytesIO(file_content)
  51. filename = "test.txt"
  52. filename_extra = "test_exyta.txt"
  53. file_bytesio_empty = io.BytesIO()
  54. def test_upload_file(self, monkeypatch, tmp_path):
  55. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  56. contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
  57. assert (upload_dir / self.filename).exists()
  58. assert (upload_dir / self.filename).read_bytes() == self.file_content
  59. assert contents == self.file_content
  60. assert file_path == str(upload_dir / self.filename)
  61. with pytest.raises(ValueError):
  62. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  63. def test_get_file(self, monkeypatch, tmp_path):
  64. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  65. file_path = str(upload_dir / self.filename)
  66. file_path_return = self.Storage.get_file(file_path)
  67. assert file_path == file_path_return
  68. def test_delete_file(self, monkeypatch, tmp_path):
  69. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  70. (upload_dir / self.filename).write_bytes(self.file_content)
  71. assert (upload_dir / self.filename).exists()
  72. file_path = str(upload_dir / self.filename)
  73. self.Storage.delete_file(file_path)
  74. assert not (upload_dir / self.filename).exists()
  75. def test_delete_all_files(self, monkeypatch, tmp_path):
  76. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  77. (upload_dir / self.filename).write_bytes(self.file_content)
  78. (upload_dir / self.filename_extra).write_bytes(self.file_content)
  79. self.Storage.delete_all_files()
  80. assert not (upload_dir / self.filename).exists()
  81. assert not (upload_dir / self.filename_extra).exists()
  82. @mock_aws
  83. class TestS3StorageProvider:
  84. def __init__(self):
  85. self.Storage = provider.S3StorageProvider()
  86. self.Storage.bucket_name = "my-bucket"
  87. self.s3_client = boto3.resource("s3", region_name="us-east-1")
  88. self.file_content = b"test content"
  89. self.filename = "test.txt"
  90. self.filename_extra = "test_exyta.txt"
  91. self.file_bytesio_empty = io.BytesIO()
  92. super().__init__()
  93. def test_upload_file(self, monkeypatch, tmp_path):
  94. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  95. # S3 checks
  96. with pytest.raises(Exception):
  97. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  98. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  99. contents, s3_file_path = self.Storage.upload_file(
  100. io.BytesIO(self.file_content), self.filename
  101. )
  102. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  103. assert self.file_content == object.get()["Body"].read()
  104. # local checks
  105. assert (upload_dir / self.filename).exists()
  106. assert (upload_dir / self.filename).read_bytes() == self.file_content
  107. assert contents == self.file_content
  108. assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
  109. with pytest.raises(ValueError):
  110. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  111. def test_get_file(self, monkeypatch, tmp_path):
  112. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  113. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  114. contents, s3_file_path = self.Storage.upload_file(
  115. io.BytesIO(self.file_content), self.filename
  116. )
  117. file_path = self.Storage.get_file(s3_file_path)
  118. assert file_path == str(upload_dir / self.filename)
  119. assert (upload_dir / self.filename).exists()
  120. def test_delete_file(self, monkeypatch, tmp_path):
  121. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  122. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  123. contents, s3_file_path = self.Storage.upload_file(
  124. io.BytesIO(self.file_content), self.filename
  125. )
  126. assert (upload_dir / self.filename).exists()
  127. self.Storage.delete_file(s3_file_path)
  128. assert not (upload_dir / self.filename).exists()
  129. with pytest.raises(ClientError) as exc:
  130. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  131. error = exc.value.response["Error"]
  132. assert error["Code"] == "404"
  133. assert error["Message"] == "Not Found"
  134. def test_delete_all_files(self, monkeypatch, tmp_path):
  135. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  136. # create 2 files
  137. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  138. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  139. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  140. assert self.file_content == object.get()["Body"].read()
  141. assert (upload_dir / self.filename).exists()
  142. assert (upload_dir / self.filename).read_bytes() == self.file_content
  143. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  144. object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
  145. assert self.file_content == object.get()["Body"].read()
  146. assert (upload_dir / self.filename).exists()
  147. assert (upload_dir / self.filename).read_bytes() == self.file_content
  148. self.Storage.delete_all_files()
  149. assert not (upload_dir / self.filename).exists()
  150. with pytest.raises(ClientError) as exc:
  151. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  152. error = exc.value.response["Error"]
  153. assert error["Code"] == "404"
  154. assert error["Message"] == "Not Found"
  155. assert not (upload_dir / self.filename_extra).exists()
  156. with pytest.raises(ClientError) as exc:
  157. self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
  158. error = exc.value.response["Error"]
  159. assert error["Code"] == "404"
  160. assert error["Message"] == "Not Found"
  161. self.Storage.delete_all_files()
  162. assert not (upload_dir / self.filename).exists()
  163. assert not (upload_dir / self.filename_extra).exists()
  164. def test_init_without_credentials(self, monkeypatch):
  165. """Test that S3StorageProvider can initialize without explicit credentials."""
  166. # Temporarily unset the environment variables
  167. monkeypatch.setattr(provider, "S3_ACCESS_KEY_ID", None)
  168. monkeypatch.setattr(provider, "S3_SECRET_ACCESS_KEY", None)
  169. # Should not raise an exception
  170. storage = provider.S3StorageProvider()
  171. assert storage.s3_client is not None
  172. assert storage.bucket_name == provider.S3_BUCKET_NAME
  173. class TestGCSStorageProvider:
  174. Storage = provider.GCSStorageProvider()
  175. Storage.bucket_name = "my-bucket"
  176. file_content = b"test content"
  177. filename = "test.txt"
  178. filename_extra = "test_exyta.txt"
  179. file_bytesio_empty = io.BytesIO()
  180. @pytest.fixture(scope="class")
  181. def setup(self):
  182. host, port = "localhost", 9023
  183. server = create_server(host, port, in_memory=True)
  184. server.start()
  185. os.environ["STORAGE_EMULATOR_HOST"] = f"http://{host}:{port}"
  186. gcs_client = storage.Client()
  187. bucket = gcs_client.bucket(self.Storage.bucket_name)
  188. bucket.create()
  189. self.Storage.gcs_client, self.Storage.bucket = gcs_client, bucket
  190. yield
  191. bucket.delete(force=True)
  192. server.stop()
  193. def test_upload_file(self, monkeypatch, tmp_path, setup):
  194. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  195. # catch error if bucket does not exist
  196. with pytest.raises(Exception):
  197. self.Storage.bucket = monkeypatch(self.Storage, "bucket", None)
  198. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  199. contents, gcs_file_path = self.Storage.upload_file(
  200. io.BytesIO(self.file_content), self.filename
  201. )
  202. object = self.Storage.bucket.get_blob(self.filename)
  203. assert self.file_content == object.download_as_bytes()
  204. # local checks
  205. assert (upload_dir / self.filename).exists()
  206. assert (upload_dir / self.filename).read_bytes() == self.file_content
  207. assert contents == self.file_content
  208. assert gcs_file_path == "gs://" + self.Storage.bucket_name + "/" + self.filename
  209. # test error if file is empty
  210. with pytest.raises(ValueError):
  211. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  212. def test_get_file(self, monkeypatch, tmp_path, setup):
  213. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  214. contents, gcs_file_path = self.Storage.upload_file(
  215. io.BytesIO(self.file_content), self.filename
  216. )
  217. file_path = self.Storage.get_file(gcs_file_path)
  218. assert file_path == str(upload_dir / self.filename)
  219. assert (upload_dir / self.filename).exists()
  220. def test_delete_file(self, monkeypatch, tmp_path, setup):
  221. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  222. contents, gcs_file_path = self.Storage.upload_file(
  223. io.BytesIO(self.file_content), self.filename
  224. )
  225. # ensure that local directory has the uploaded file as well
  226. assert (upload_dir / self.filename).exists()
  227. assert self.Storage.bucket.get_blob(self.filename).name == self.filename
  228. self.Storage.delete_file(gcs_file_path)
  229. # check that deleting file from gcs will delete the local file as well
  230. assert not (upload_dir / self.filename).exists()
  231. assert self.Storage.bucket.get_blob(self.filename) == None
  232. def test_delete_all_files(self, monkeypatch, tmp_path, setup):
  233. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  234. # create 2 files
  235. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  236. object = self.Storage.bucket.get_blob(self.filename)
  237. assert (upload_dir / self.filename).exists()
  238. assert (upload_dir / self.filename).read_bytes() == self.file_content
  239. assert self.Storage.bucket.get_blob(self.filename).name == self.filename
  240. assert self.file_content == object.download_as_bytes()
  241. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  242. object = self.Storage.bucket.get_blob(self.filename_extra)
  243. assert (upload_dir / self.filename_extra).exists()
  244. assert (upload_dir / self.filename_extra).read_bytes() == self.file_content
  245. assert (
  246. self.Storage.bucket.get_blob(self.filename_extra).name
  247. == self.filename_extra
  248. )
  249. assert self.file_content == object.download_as_bytes()
  250. self.Storage.delete_all_files()
  251. assert not (upload_dir / self.filename).exists()
  252. assert not (upload_dir / self.filename_extra).exists()
  253. assert self.Storage.bucket.get_blob(self.filename) == None
  254. assert self.Storage.bucket.get_blob(self.filename_extra) == None
  255. class TestAzureStorageProvider:
  256. def __init__(self):
  257. super().__init__()
  258. @pytest.fixture(scope="class")
  259. def setup_storage(self, monkeypatch):
  260. # Create mock Blob Service Client and related clients
  261. mock_blob_service_client = MagicMock()
  262. mock_container_client = MagicMock()
  263. mock_blob_client = MagicMock()
  264. # Set up return values for the mock
  265. mock_blob_service_client.get_container_client.return_value = (
  266. mock_container_client
  267. )
  268. mock_container_client.get_blob_client.return_value = mock_blob_client
  269. # Monkeypatch the Azure classes to return our mocks
  270. monkeypatch.setattr(
  271. azure.storage.blob,
  272. "BlobServiceClient",
  273. lambda *args, **kwargs: mock_blob_service_client,
  274. )
  275. monkeypatch.setattr(
  276. azure.storage.blob,
  277. "ContainerClient",
  278. lambda *args, **kwargs: mock_container_client,
  279. )
  280. monkeypatch.setattr(
  281. azure.storage.blob, "BlobClient", lambda *args, **kwargs: mock_blob_client
  282. )
  283. self.Storage = provider.AzureStorageProvider()
  284. self.Storage.endpoint = "https://myaccount.blob.core.windows.net"
  285. self.Storage.container_name = "my-container"
  286. self.file_content = b"test content"
  287. self.filename = "test.txt"
  288. self.filename_extra = "test_extra.txt"
  289. self.file_bytesio_empty = io.BytesIO()
  290. # Apply mocks to the Storage instance
  291. self.Storage.blob_service_client = mock_blob_service_client
  292. self.Storage.container_client = mock_container_client
  293. def test_upload_file(self, monkeypatch, tmp_path):
  294. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  295. # Simulate an error when container does not exist
  296. self.Storage.container_client.get_blob_client.side_effect = Exception(
  297. "Container does not exist"
  298. )
  299. with pytest.raises(Exception):
  300. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  301. # Reset side effect and create container
  302. self.Storage.container_client.get_blob_client.side_effect = None
  303. self.Storage.create_container()
  304. contents, azure_file_path = self.Storage.upload_file(
  305. io.BytesIO(self.file_content), self.filename
  306. )
  307. # Assertions
  308. self.Storage.container_client.get_blob_client.assert_called_with(self.filename)
  309. self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with(
  310. self.file_content, overwrite=True
  311. )
  312. assert contents == self.file_content
  313. assert (
  314. azure_file_path
  315. == f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
  316. )
  317. assert (upload_dir / self.filename).exists()
  318. assert (upload_dir / self.filename).read_bytes() == self.file_content
  319. with pytest.raises(ValueError):
  320. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  321. def test_get_file(self, monkeypatch, tmp_path):
  322. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  323. self.Storage.create_container()
  324. # Mock upload behavior
  325. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  326. # Mock blob download behavior
  327. self.Storage.container_client.get_blob_client().download_blob().readall.return_value = (
  328. self.file_content
  329. )
  330. file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
  331. file_path = self.Storage.get_file(file_url)
  332. assert file_path == str(upload_dir / self.filename)
  333. assert (upload_dir / self.filename).exists()
  334. assert (upload_dir / self.filename).read_bytes() == self.file_content
  335. def test_delete_file(self, monkeypatch, tmp_path):
  336. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  337. self.Storage.create_container()
  338. # Mock file upload
  339. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  340. # Mock deletion
  341. self.Storage.container_client.get_blob_client().delete_blob.return_value = None
  342. file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
  343. self.Storage.delete_file(file_url)
  344. self.Storage.container_client.get_blob_client().delete_blob.assert_called_once()
  345. assert not (upload_dir / self.filename).exists()
  346. def test_delete_all_files(self, monkeypatch, tmp_path):
  347. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  348. self.Storage.create_container()
  349. # Mock file uploads
  350. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  351. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  352. # Mock listing and deletion behavior
  353. self.Storage.container_client.list_blobs.return_value = [
  354. {"name": self.filename},
  355. {"name": self.filename_extra},
  356. ]
  357. self.Storage.container_client.get_blob_client().delete_blob.return_value = None
  358. self.Storage.delete_all_files()
  359. self.Storage.container_client.list_blobs.assert_called_once()
  360. self.Storage.container_client.get_blob_client().delete_blob.assert_any_call()
  361. assert not (upload_dir / self.filename).exists()
  362. assert not (upload_dir / self.filename_extra).exists()
  363. def test_get_file_not_found(self, monkeypatch):
  364. self.Storage.create_container()
  365. file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
  366. # Mock behavior to raise an error for missing blobs
  367. self.Storage.container_client.get_blob_client().download_blob.side_effect = (
  368. Exception("Blob not found")
  369. )
  370. with pytest.raises(Exception, match="Blob not found"):
  371. self.Storage.get_file(file_url)