test_provider.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. import io
  2. import os
  3. import boto3
  4. import pytest
  5. from botocore.exceptions import ClientError
  6. from moto import mock_aws, mock_azure
  7. from open_webui.storage import provider
  8. from gcp_storage_emulator.server import create_server
  9. from google.cloud import storage
  10. from azure.storage.blob import BlobServiceClient
  11. def mock_upload_dir(monkeypatch, tmp_path):
  12. """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
  13. directory = tmp_path / "uploads"
  14. directory.mkdir()
  15. monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
  16. return directory
  17. def test_imports():
  18. provider.StorageProvider
  19. provider.LocalStorageProvider
  20. provider.S3StorageProvider
  21. provider.GCSStorageProvider
  22. provider.AzureStorageProvider
  23. provider.Storage
  24. def test_get_storage_provider():
  25. Storage = provider.get_storage_provider("local")
  26. assert isinstance(Storage, provider.LocalStorageProvider)
  27. Storage = provider.get_storage_provider("s3")
  28. assert isinstance(Storage, provider.S3StorageProvider)
  29. Storage = provider.get_storage_provider("gcs")
  30. assert isinstance(Storage, provider.GCSStorageProvider)
  31. Storage = provider.get_storage_provider("azure")
  32. assert isinstance(Storage, provider.AzureStorageProvider)
  33. with pytest.raises(RuntimeError):
  34. provider.get_storage_provider("invalid")
  35. def test_class_instantiation():
  36. with pytest.raises(TypeError):
  37. provider.StorageProvider()
  38. with pytest.raises(TypeError):
  39. class Test(provider.StorageProvider):
  40. pass
  41. Test()
  42. provider.LocalStorageProvider()
  43. provider.S3StorageProvider()
  44. provider.GCSStorageProvider()
  45. provider.AzureStorageProvider()
  46. class TestLocalStorageProvider:
  47. Storage = provider.LocalStorageProvider()
  48. file_content = b"test content"
  49. file_bytesio = io.BytesIO(file_content)
  50. filename = "test.txt"
  51. filename_extra = "test_exyta.txt"
  52. file_bytesio_empty = io.BytesIO()
  53. def test_upload_file(self, monkeypatch, tmp_path):
  54. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  55. contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
  56. assert (upload_dir / self.filename).exists()
  57. assert (upload_dir / self.filename).read_bytes() == self.file_content
  58. assert contents == self.file_content
  59. assert file_path == str(upload_dir / self.filename)
  60. with pytest.raises(ValueError):
  61. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  62. def test_get_file(self, monkeypatch, tmp_path):
  63. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  64. file_path = str(upload_dir / self.filename)
  65. file_path_return = self.Storage.get_file(file_path)
  66. assert file_path == file_path_return
  67. def test_delete_file(self, monkeypatch, tmp_path):
  68. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  69. (upload_dir / self.filename).write_bytes(self.file_content)
  70. assert (upload_dir / self.filename).exists()
  71. file_path = str(upload_dir / self.filename)
  72. self.Storage.delete_file(file_path)
  73. assert not (upload_dir / self.filename).exists()
  74. def test_delete_all_files(self, monkeypatch, tmp_path):
  75. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  76. (upload_dir / self.filename).write_bytes(self.file_content)
  77. (upload_dir / self.filename_extra).write_bytes(self.file_content)
  78. self.Storage.delete_all_files()
  79. assert not (upload_dir / self.filename).exists()
  80. assert not (upload_dir / self.filename_extra).exists()
  81. @mock_aws
  82. class TestS3StorageProvider:
  83. def __init__(self):
  84. self.Storage = provider.S3StorageProvider()
  85. self.Storage.bucket_name = "my-bucket"
  86. self.s3_client = boto3.resource("s3", region_name="us-east-1")
  87. self.file_content = b"test content"
  88. self.filename = "test.txt"
  89. self.filename_extra = "test_exyta.txt"
  90. self.file_bytesio_empty = io.BytesIO()
  91. super().__init__()
  92. def test_upload_file(self, monkeypatch, tmp_path):
  93. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  94. # S3 checks
  95. with pytest.raises(Exception):
  96. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  97. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  98. contents, s3_file_path = self.Storage.upload_file(
  99. io.BytesIO(self.file_content), self.filename
  100. )
  101. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  102. assert self.file_content == object.get()["Body"].read()
  103. # local checks
  104. assert (upload_dir / self.filename).exists()
  105. assert (upload_dir / self.filename).read_bytes() == self.file_content
  106. assert contents == self.file_content
  107. assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
  108. with pytest.raises(ValueError):
  109. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  110. def test_get_file(self, monkeypatch, tmp_path):
  111. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  112. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  113. contents, s3_file_path = self.Storage.upload_file(
  114. io.BytesIO(self.file_content), self.filename
  115. )
  116. file_path = self.Storage.get_file(s3_file_path)
  117. assert file_path == str(upload_dir / self.filename)
  118. assert (upload_dir / self.filename).exists()
  119. def test_delete_file(self, monkeypatch, tmp_path):
  120. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  121. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  122. contents, s3_file_path = self.Storage.upload_file(
  123. io.BytesIO(self.file_content), self.filename
  124. )
  125. assert (upload_dir / self.filename).exists()
  126. self.Storage.delete_file(s3_file_path)
  127. assert not (upload_dir / self.filename).exists()
  128. with pytest.raises(ClientError) as exc:
  129. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  130. error = exc.value.response["Error"]
  131. assert error["Code"] == "404"
  132. assert error["Message"] == "Not Found"
  133. def test_delete_all_files(self, monkeypatch, tmp_path):
  134. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  135. # create 2 files
  136. self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
  137. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  138. object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
  139. assert self.file_content == object.get()["Body"].read()
  140. assert (upload_dir / self.filename).exists()
  141. assert (upload_dir / self.filename).read_bytes() == self.file_content
  142. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  143. object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
  144. assert self.file_content == object.get()["Body"].read()
  145. assert (upload_dir / self.filename).exists()
  146. assert (upload_dir / self.filename).read_bytes() == self.file_content
  147. self.Storage.delete_all_files()
  148. assert not (upload_dir / self.filename).exists()
  149. with pytest.raises(ClientError) as exc:
  150. self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
  151. error = exc.value.response["Error"]
  152. assert error["Code"] == "404"
  153. assert error["Message"] == "Not Found"
  154. assert not (upload_dir / self.filename_extra).exists()
  155. with pytest.raises(ClientError) as exc:
  156. self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
  157. error = exc.value.response["Error"]
  158. assert error["Code"] == "404"
  159. assert error["Message"] == "Not Found"
  160. self.Storage.delete_all_files()
  161. assert not (upload_dir / self.filename).exists()
  162. assert not (upload_dir / self.filename_extra).exists()
  163. class TestGCSStorageProvider:
  164. Storage = provider.GCSStorageProvider()
  165. Storage.bucket_name = "my-bucket"
  166. file_content = b"test content"
  167. filename = "test.txt"
  168. filename_extra = "test_exyta.txt"
  169. file_bytesio_empty = io.BytesIO()
  170. @pytest.fixture(scope="class")
  171. def setup(self):
  172. host, port = "localhost", 9023
  173. server = create_server(host, port, in_memory=True)
  174. server.start()
  175. os.environ["STORAGE_EMULATOR_HOST"] = f"http://{host}:{port}"
  176. gcs_client = storage.Client()
  177. bucket = gcs_client.bucket(self.Storage.bucket_name)
  178. bucket.create()
  179. self.Storage.gcs_client, self.Storage.bucket = gcs_client, bucket
  180. yield
  181. bucket.delete(force=True)
  182. server.stop()
  183. def test_upload_file(self, monkeypatch, tmp_path, setup):
  184. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  185. # catch error if bucket does not exist
  186. with pytest.raises(Exception):
  187. self.Storage.bucket = monkeypatch(self.Storage, "bucket", None)
  188. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  189. contents, gcs_file_path = self.Storage.upload_file(
  190. io.BytesIO(self.file_content), self.filename
  191. )
  192. object = self.Storage.bucket.get_blob(self.filename)
  193. assert self.file_content == object.download_as_bytes()
  194. # local checks
  195. assert (upload_dir / self.filename).exists()
  196. assert (upload_dir / self.filename).read_bytes() == self.file_content
  197. assert contents == self.file_content
  198. assert gcs_file_path == "gs://" + self.Storage.bucket_name + "/" + self.filename
  199. # test error if file is empty
  200. with pytest.raises(ValueError):
  201. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  202. def test_get_file(self, monkeypatch, tmp_path, setup):
  203. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  204. contents, gcs_file_path = self.Storage.upload_file(
  205. io.BytesIO(self.file_content), self.filename
  206. )
  207. file_path = self.Storage.get_file(gcs_file_path)
  208. assert file_path == str(upload_dir / self.filename)
  209. assert (upload_dir / self.filename).exists()
  210. def test_delete_file(self, monkeypatch, tmp_path, setup):
  211. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  212. contents, gcs_file_path = self.Storage.upload_file(
  213. io.BytesIO(self.file_content), self.filename
  214. )
  215. # ensure that local directory has the uploaded file as well
  216. assert (upload_dir / self.filename).exists()
  217. assert self.Storage.bucket.get_blob(self.filename).name == self.filename
  218. self.Storage.delete_file(gcs_file_path)
  219. # check that deleting file from gcs will delete the local file as well
  220. assert not (upload_dir / self.filename).exists()
  221. assert self.Storage.bucket.get_blob(self.filename) == None
  222. def test_delete_all_files(self, monkeypatch, tmp_path, setup):
  223. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  224. # create 2 files
  225. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  226. object = self.Storage.bucket.get_blob(self.filename)
  227. assert (upload_dir / self.filename).exists()
  228. assert (upload_dir / self.filename).read_bytes() == self.file_content
  229. assert self.Storage.bucket.get_blob(self.filename).name == self.filename
  230. assert self.file_content == object.download_as_bytes()
  231. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  232. object = self.Storage.bucket.get_blob(self.filename_extra)
  233. assert (upload_dir / self.filename_extra).exists()
  234. assert (upload_dir / self.filename_extra).read_bytes() == self.file_content
  235. assert (
  236. self.Storage.bucket.get_blob(self.filename_extra).name
  237. == self.filename_extra
  238. )
  239. assert self.file_content == object.download_as_bytes()
  240. self.Storage.delete_all_files()
  241. assert not (upload_dir / self.filename).exists()
  242. assert not (upload_dir / self.filename_extra).exists()
  243. assert self.Storage.bucket.get_blob(self.filename) == None
  244. assert self.Storage.bucket.get_blob(self.filename_extra) == None
  245. class TestAzureStorageProvider:
  246. def __init__(self):
  247. self.Storage = provider.AzureStorageProvider()
  248. self.Storage.container_name = "my-container"
  249. self.file_content = b"test content"
  250. self.filename = "test.txt"
  251. self.filename_extra = "test_exyta.txt"
  252. self.file_bytesio_empty = io.BytesIO()
  253. super().__init__()
  254. @pytest.fixture(scope="class")
  255. def setup(self, monkeypatch):
  256. connection_string = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtl6rE4rWlgEoMF1rA==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
  257. self.Storage.blob_service_client = BlobServiceClient.from_connection_string(connection_string)
  258. self.Storage.container_client = self.Storage.blob_service_client.get_container_client(self.Storage.container_name)
  259. monkeypatch.setattr(self.Storage, "blob_service_client", self.Storage.blob_service_client)
  260. monkeypatch.setattr(self.Storage, "container_client", self.Storage.container_client)
  261. yield
  262. self.Storage.container_client.delete_container()
  263. def test_upload_file(self, monkeypatch, tmp_path, setup):
  264. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  265. # Azure checks
  266. with pytest.raises(Exception):
  267. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  268. self.Storage.create_container()
  269. contents, azure_file_path = self.Storage.upload_file(
  270. io.BytesIO(self.file_content), self.filename
  271. )
  272. blob = self.Storage.blob_service_client.get_blob_client(
  273. container=self.Storage.container_name, blob=self.filename
  274. )
  275. assert self.file_content == blob.download_blob().readall()
  276. # local checks
  277. assert (upload_dir / self.filename).exists()
  278. assert (upload_dir / self.filename).read_bytes() == self.file_content
  279. assert contents == self.file_content
  280. assert azure_file_path == "azure://" + self.Storage.container_name + "/" + self.filename
  281. with pytest.raises(ValueError):
  282. self.Storage.upload_file(self.file_bytesio_empty, self.filename)
  283. def test_get_file(self, monkeypatch, tmp_path, setup):
  284. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  285. self.Storage.create_container()
  286. contents, azure_file_path = self.Storage.upload_file(
  287. io.BytesIO(self.file_content), self.filename
  288. )
  289. file_path = self.Storage.get_file(azure_file_path)
  290. assert file_path == str(upload_dir / self.filename)
  291. assert (upload_dir / self.filename).exists()
  292. def test_delete_file(self, monkeypatch, tmp_path, setup):
  293. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  294. self.Storage.create_container()
  295. contents, azure_file_path = self.Storage.upload_file(
  296. io.BytesIO(self.file_content), self.filename
  297. )
  298. assert (upload_dir / self.filename).exists()
  299. self.Storage.delete_file(azure_file_path)
  300. assert not (upload_dir / self.filename).exists()
  301. blob = self.Storage.blob_service_client.get_blob_client(
  302. container=self.Storage.container_name, blob=self.filename
  303. )
  304. with pytest.raises(Exception):
  305. blob.download_blob().readall()
  306. def test_delete_all_files(self, monkeypatch, tmp_path, setup):
  307. upload_dir = mock_upload_dir(monkeypatch, tmp_path)
  308. self.Storage.create_container()
  309. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
  310. blob = self.Storage.blob_service_client.get_blob_client(
  311. container=self.Storage.container_name, blob=self.filename
  312. )
  313. assert self.file_content == blob.download_blob().readall()
  314. assert (upload_dir / self.filename).exists()
  315. self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
  316. blob = self.Storage.blob_service_client.get_blob_client(
  317. container=self.Storage.container_name, blob=self.filename_extra
  318. )
  319. assert self.file_content == blob.download_blob().readall()
  320. assert (upload_dir / self.filename_extra).exists()
  321. self.Storage.delete_all_files()
  322. assert not (upload_dir / self.filename).exists()
  323. assert not (upload_dir / self.filename_extra).exists()
  324. blob = self.Storage.blob_service_client.get_blob_client(
  325. container=self.Storage.container_name, blob=self.filename
  326. )
  327. with pytest.raises(Exception):
  328. blob.download_blob().readall()
  329. blob = self.Storage.blob_service_client.get_blob_client(
  330. container=self.Storage.container_name, blob=self.filename_extra
  331. )
  332. with pytest.raises(Exception):
  333. blob.download_blob().readall()