test_provider.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. import io
  2. import os
  3. import boto3
  4. import pytest
  5. from botocore.exceptions import ClientError
  6. from moto import mock_aws, mock_azure
  7. from open_webui.storage import provider
  8. from gcp_storage_emulator.server import create_server
  9. from google.cloud import storage
  10. from azure.storage.blob import BlobServiceClient, BlobContainerClient, BlobClient
  11. from unittest.mock import MagicMock
  12. def mock_upload_dir(monkeypatch, tmp_path):
  13. """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
  14. directory = tmp_path / "uploads"
  15. directory.mkdir()
  16. monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
  17. return directory
  18. def test_imports():
  19. provider.StorageProvider
  20. provider.LocalStorageProvider
  21. provider.S3StorageProvider
  22. provider.GCSStorageProvider
  23. provider.AzureStorageProvider
  24. provider.Storage
  25. def test_get_storage_provider():
  26. Storage = provider.get_storage_provider("local")
  27. assert isinstance(Storage, provider.LocalStorageProvider)
  28. Storage = provider.get_storage_provider("s3")
  29. assert isinstance(Storage, provider.S3StorageProvider)
  30. Storage = provider.get_storage_provider("gcs")
  31. assert isinstance(Storage, provider.GCSStorageProvider)
  32. Storage = provider.get_storage_provider("azure")
  33. assert isinstance(Storage, provider.AzureStorageProvider)
  34. with pytest.raises(RuntimeError):
  35. provider.get_storage_provider("invalid")
  36. def test_class_instantiation():
  37. with pytest.raises(TypeError):
  38. provider.StorageProvider()
  39. with pytest.raises(TypeError):
  40. class Test(provider.StorageProvider):
  41. pass
  42. Test()
  43. provider.LocalStorageProvider()
  44. provider.S3StorageProvider()
  45. provider.GCSStorageProvider()
  46. provider.AzureStorageProvider()
  47. class TestLocalStorageProvider:
  48. Storage = provider.LocalStorageProvider()
  49. file_content = b"test content"
  50. file_bytesio = io.BytesIO(file_content)
  51. filename = "test.txt"
  52. filename_extra = "test_exyta.txt"
  53. file_bytesio_empty = io.BytesIO()
  54. def test_upload_file(self, monkeypatch, tmp_path):
  55. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  56. contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
  57. assert (upload_dir / self.filename).exists()
  58. assert (upload_dir / self.filename).read_bytes() == self.file_content
  59. assert contents == self.file_content
  60. assert file_path == str(upload_dir / self.filename)
  61. with pytest.raises(ValueError):
  62. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  63. def test_get_file(self, monkeypatch, tmp_path):
  64. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  65. file_path = str(upload_dir / self.filename)
  66. file_path_return = self.Storage.get_file(file_path)
  67. assert file_path == file_path_return
  68. def test_delete_file(self, monkeypatch, tmp_path):
  69. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  70. (upload_dir / self.filename).write_bytes(self.file_content)
  71. assert (upload_dir / self.filename).exists()
  72. file_path = str(upload_dir / self.filename)
  73. self.Storage.delete_file(file_path)
  74. assert not (upload_dir / self.filename).exists()
  75. def test_delete_all_files(self, monkeypatch, tmp_path):
  76. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  77. (upload_dir / self.filename).write_bytes(self.file_content)
  78. (upload_dir / self.filename_extra).write_bytes(self.file_content)
  79. self.Storage.delete_all_files()
  80. assert not (upload_dir / self.filename).exists()
  81. assert not (upload_dir / self.filename_extra).exists()
  82. @mock_aws
  83. class TestS3StorageProvider:
  84. def __init__(self):
  85. self.Storage = provider.S3StorageProvider()
  86. self.Storage.bucket_name = "my-bucket"
  87. self.s3_client = boto3.resource("s3", region_name="us-east-1")
  88. self.file_content = b"test content"
  89. self.filename = "test.txt"
  90. self.filename_extra = "test_exyta.txt"
  91. self.file_bytesio_empty = io.BytesIO()
  92. super().__init__()
  93. def test_upload_file(self, monkeypatch, tmp_path):
  94. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  95. # S3 checks
  96. with pytest.raises(Exception):
  97. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  98. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  99. contents, s3_file_path = self.Storage.upload_file(
  100. io.BytesIO(self.file_content), self.filename
  101. )
  102. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  103. assert self.file_content == object.get()["Body"].read()
  104. # local checks
  105. assert (upload_dir / self.filename).exists()
  106. assert (upload_dir / self.filename).read_bytes() == self.file_content
  107. assert contents == self.file_content
  108. assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
  109. with pytest.raises(ValueError):
  110. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  111. def test_get_file(self, monkeypatch, tmp_path):
  112. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  113. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  114. contents, s3_file_path = self.Storage.upload_file(
  115. io.BytesIO(self.file_content), self.filename
  116. )
  117. file_path = self.Storage.get_file(s3_file_path)
  118. assert file_path == str(upload_dir / self.filename)
  119. assert (upload_dir / self.filename).exists()
  120. def test_delete_file(self, monkeypatch, tmp_path):
  121. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  122. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  123. contents, s3_file_path = self.Storage.upload_file(
  124. io.BytesIO(self.file_content), self.filename
  125. )
  126. assert (upload_dir / self.filename).exists()
  127. self.Storage.delete_file(s3_file_path)
  128. assert not (upload_dir / self.filename).exists()
  129. with pytest.raises(ClientError) as exc:
  130. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  131. error = exc.value.response["Error"]
  132. assert error["Code"] == "404"
  133. assert error["Message"] == "Not Found"
  134. def test_delete_all_files(self, monkeypatch, tmp_path):
  135. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  136. # create 2 files
  137. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  138. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  139. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  140. assert self.file_content == object.get()["Body"].read()
  141. assert (upload_dir / self.filename).exists()
  142. assert (upload_dir / self.filename).read_bytes() == self.file_content
  143. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  144. object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
  145. assert self.file_content == object.get()["Body"].read()
  146. assert (upload_dir / self.filename).exists()
  147. assert (upload_dir / self.filename).read_bytes() == self.file_content
  148. self.Storage.delete_all_files()
  149. assert not (upload_dir / self.filename).exists()
  150. with pytest.raises(ClientError) as exc:
  151. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  152. error = exc.value.response["Error"]
  153. assert error["Code"] == "404"
  154. assert error["Message"] == "Not Found"
  155. assert not (upload_dir / self.filename_extra).exists()
  156. with pytest.raises(ClientError) as exc:
  157. self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
  158. error = exc.value.response["Error"]
  159. assert error["Code"] == "404"
  160. assert error["Message"] == "Not Found"
  161. self.Storage.delete_all_files()
  162. assert not (upload_dir / self.filename).exists()
  163. assert not (upload_dir / self.filename_extra).exists()
  164. class TestGCSStorageProvider:
  165. Storage = provider.GCSStorageProvider()
  166. Storage.bucket_name = "my-bucket"
  167. file_content = b"test content"
  168. filename = "test.txt"
  169. filename_extra = "test_exyta.txt"
  170. file_bytesio_empty = io.BytesIO()
  171. @pytest.fixture(scope="class")
  172. def setup(self):
  173. host, port = "localhost", 9023
  174. server = create_server(host, port, in_memory=True)
  175. server.start()
  176. os.environ["STORAGE_EMULATOR_HOST"] = f"http://{host}:{port}"
  177. gcs_client = storage.Client()
  178. bucket = gcs_client.bucket(self.Storage.bucket_name)
  179. bucket.create()
  180. self.Storage.gcs_client, self.Storage.bucket = gcs_client, bucket
  181. yield
  182. bucket.delete(force=True)
  183. server.stop()
  184. def test_upload_file(self, monkeypatch, tmp_path, setup):
  185. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  186. # catch error if bucket does not exist
  187. with pytest.raises(Exception):
  188. self.Storage.bucket = monkeypatch(self.Storage, "bucket", None)
  189. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  190. contents, gcs_file_path = self.Storage.upload_file(
  191. io.BytesIO(self.file_content), self.filename
  192. )
  193. object = self.Storage.bucket.get_blob(self.filename)
  194. assert self.file_content == object.download_as_bytes()
  195. # local checks
  196. assert (upload_dir / self.filename).exists()
  197. assert (upload_dir / self.filename).read_bytes() == self.file_content
  198. assert contents == self.file_content
  199. assert gcs_file_path == "gs://" + self.Storage.bucket_name + "/" + self.filename
  200. # test error if file is empty
  201. with pytest.raises(ValueError):
  202. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  203. def test_get_file(self, monkeypatch, tmp_path, setup):
  204. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  205. contents, gcs_file_path = self.Storage.upload_file(
  206. io.BytesIO(self.file_content), self.filename
  207. )
  208. file_path = self.Storage.get_file(gcs_file_path)
  209. assert file_path == str(upload_dir / self.filename)
  210. assert (upload_dir / self.filename).exists()
  211. def test_delete_file(self, monkeypatch, tmp_path, setup):
  212. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  213. contents, gcs_file_path = self.Storage.upload_file(
  214. io.BytesIO(self.file_content), self.filename
  215. )
  216. # ensure that local directory has the uploaded file as well
  217. assert (upload_dir / self.filename).exists()
  218. assert self.Storage.bucket.get_blob(self.filename).name == self.filename
  219. self.Storage.delete_file(gcs_file_path)
  220. # check that deleting file from gcs will delete the local file as well
  221. assert not (upload_dir / self.filename).exists()
  222. assert self.Storage.bucket.get_blob(self.filename) == None
  223. def test_delete_all_files(self, monkeypatch, tmp_path, setup):
  224. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  225. # create 2 files
  226. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  227. object = self.Storage.bucket.get_blob(self.filename)
  228. assert (upload_dir / self.filename).exists()
  229. assert (upload_dir / self.filename).read_bytes() == self.file_content
  230. assert self.Storage.bucket.get_blob(self.filename).name == self.filename
  231. assert self.file_content == object.download_as_bytes()
  232. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  233. object = self.Storage.bucket.get_blob(self.filename_extra)
  234. assert (upload_dir / self.filename_extra).exists()
  235. assert (upload_dir / self.filename_extra).read_bytes() == self.file_content
  236. assert (
  237. self.Storage.bucket.get_blob(self.filename_extra).name
  238. == self.filename_extra
  239. )
  240. assert self.file_content == object.download_as_bytes()
  241. self.Storage.delete_all_files()
  242. assert not (upload_dir / self.filename).exists()
  243. assert not (upload_dir / self.filename_extra).exists()
  244. assert self.Storage.bucket.get_blob(self.filename) == None
  245. assert self.Storage.bucket.get_blob(self.filename_extra) == None
  246. class TestAzureStorageProvider:
  247. def __init__(self):
  248. self.Storage = provider.AzureStorageProvider()
  249. self.Storage.endpoint = "https://myaccount.blob.core.windows.net"
  250. self.Storage.container_name = "my-container"
  251. self.file_content = b"test content"
  252. self.filename = "test.txt"
  253. self.filename_extra = "test_extra.txt"
  254. self.file_bytesio_empty = io.BytesIO()
  255. super().__init__()
  256. @pytest.fixture
  257. def setup(self, monkeypatch):
  258. """Mock BlobServiceClient and BlobContainerClient for local testing"""
  259. # Create mock Blob Service Client
  260. mock_blob_service_client = MagicMock()
  261. mock_container_client = MagicMock()
  262. mock_blob_client = MagicMock()
  263. # Set up return values
  264. mock_blob_service_client.get_container_client.return_value = mock_container_client
  265. mock_container_client.get_blob_client.return_value = mock_blob_client
  266. # Mock `from_connection_string` and `BlobServiceClient` constructor
  267. monkeypatch.setattr("azure.storage.blob.BlobServiceClient", lambda *_: mock_blob_service_client)
  268. # Apply to instance variables
  269. self.Storage.blob_service_client = mock_blob_service_client
  270. self.Storage.container_client = mock_container_client
  271. yield
  272. def test_upload_file(self, monkeypatch, tmp_path, setup):
  273. """Test uploading a file to mocked Azure Storage."""
  274. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  275. # Simulate an error when container does not exist
  276. self.Storage.container_client.get_blob_client.side_effect = Exception("Container does not exist")
  277. with pytest.raises(Exception):
  278. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  279. # Reset side effect and create container
  280. self.Storage.container_client.get_blob_client.side_effect = None
  281. self.Storage.create_container()
  282. contents, azure_file_path = self.Storage.upload_file(
  283. io.BytesIO(self.file_content), self.filename
  284. )
  285. # Assertions
  286. self.Storage.container_client.get_blob_client.assert_called_with(self.filename)
  287. self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with(self.file_content, overwrite=True)
  288. assert contents == self.file_content
  289. assert azure_file_path == f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
  290. assert (upload_dir / self.filename).exists()
  291. assert (upload_dir / self.filename).read_bytes() == self.file_content
  292. with pytest.raises(ValueError):
  293. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  294. def test_get_file(self, monkeypatch, tmp_path, setup):
  295. """Test retrieving a file from mocked Azure Storage."""
  296. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  297. self.Storage.create_container()
  298. # Mock upload behavior
  299. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  300. # Mock blob download behavior
  301. self.Storage.container_client.get_blob_client().download_blob().readall.return_value = self.file_content
  302. file_path = self.Storage.get_file(f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}")
  303. assert file_path == str(upload_dir / self.filename)
  304. assert (upload_dir / self.filename).exists()
  305. assert (upload_dir / self.filename).read_bytes() == self.file_content
  306. def test_delete_file(self, monkeypatch, tmp_path, setup):
  307. """Test deleting a file from mocked Azure Storage."""
  308. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  309. self.Storage.create_container()
  310. # Mock upload
  311. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  312. # Mock deletion
  313. self.Storage.container_client.get_blob_client().delete_blob.return_value = None
  314. self.Storage.delete_file(f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}")
  315. # Assertions
  316. self.Storage.container_client.get_blob_client().delete_blob.assert_called_once()
  317. assert not (upload_dir / self.filename).exists()
  318. def test_delete_all_files(self, monkeypatch, tmp_path, setup):
  319. """Test deleting all files from mocked Azure Storage."""
  320. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  321. self.Storage.create_container()
  322. # Mock file uploads
  323. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  324. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  325. # Mock listing and deletion behavior
  326. self.Storage.container_client.list_blobs.return_value = [
  327. {"name": self.filename},
  328. {"name": self.filename_extra},
  329. ]
  330. self.Storage.container_client.get_blob_client().delete_blob.return_value = None
  331. # Call delete all files
  332. self.Storage.delete_all_files()
  333. # Assertions
  334. self.Storage.container_client.list_blobs.assert_called_once()
  335. self.Storage.container_client.get_blob_client().delete_blob.assert_any_call()
  336. assert not (upload_dir / self.filename).exists()
  337. assert not (upload_dir / self.filename_extra).exists()
  338. def test_get_file_not_found(self, monkeypatch, setup):
  339. """Test handling when a requested file does not exist."""
  340. self.Storage.create_container()
  341. # Mock behavior to raise an error for missing files
  342. self.Storage.container_client.get_blob_client().download_blob.side_effect = Exception("Blob not found")
  343. with pytest.raises(Exception, match="Blob not found"):
  344. self.Storage.get_file(f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}")