浏览代码

Add test for S3 using moto

Rodrigo Agundez 3 月之前
父节点
当前提交
021c1f3900
共有 1 个文件被更改,包括 83 次插入19 次删除
  1. 83 19
      backend/open_webui/test/apps/webui/storage/test_provider.py

+ 83 - 19
backend/open_webui/test/apps/webui/storage/test_provider.py

@@ -1,6 +1,9 @@
 import io
 
+import boto3
 import pytest
+from botocore.exceptions import ClientError
+from moto import mock_aws
 from open_webui.storage import provider
 
 
@@ -41,7 +44,8 @@ def test_class_instantiation():
     provider.S3StorageProvider()
 
 
-class TestLocalStorageProvider(provider.LocalStorageProvider):
+class TestLocalStorageProvider:
+    Storage = provider.LocalStorageProvider()
     file_content = b"test content"
     file_bytesio = io.BytesIO(file_content)
     filename = "test.txt"
@@ -50,18 +54,18 @@ class TestLocalStorageProvider(provider.LocalStorageProvider):
 
     def test_upload_file(self, monkeypatch, tmp_path):
         upload_dir = mock_upload_dir(monkeypatch, tmp_path)
-        contents, file_path = self.upload_file(self.file_bytesio, self.filename)
+        contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
         assert (upload_dir / self.filename).exists()
         assert (upload_dir / self.filename).read_bytes() == self.file_content
         assert contents == self.file_content
         assert file_path == str(upload_dir / self.filename)
         with pytest.raises(ValueError):
-            self.upload_file(self.file_bytesio_empty, self.filename)
+            self.Storage.upload_file(self.file_bytesio_empty, self.filename)
 
     def test_get_file(self, monkeypatch, tmp_path):
         upload_dir = mock_upload_dir(monkeypatch, tmp_path)
         file_path = str(upload_dir / self.filename)
-        file_path_return = self.get_file(file_path)
+        file_path_return = self.Storage.get_file(file_path)
         assert file_path == file_path_return
 
     def test_delete_file(self, monkeypatch, tmp_path):
@@ -69,41 +73,101 @@ class TestLocalStorageProvider(provider.LocalStorageProvider):
         (upload_dir / self.filename).write_bytes(self.file_content)
         assert (upload_dir / self.filename).exists()
         file_path = str(upload_dir / self.filename)
-        self.delete_file(file_path)
+        self.Storage.delete_file(file_path)
         assert not (upload_dir / self.filename).exists()
 
     def test_delete_all_files(self, monkeypatch, tmp_path):
         upload_dir = mock_upload_dir(monkeypatch, tmp_path)
         (upload_dir / self.filename).write_bytes(self.file_content)
         (upload_dir / self.filename_extra).write_bytes(self.file_content)
-        self.delete_all_files()
+        self.Storage.delete_all_files()
         assert not (upload_dir / self.filename).exists()
         assert not (upload_dir / self.filename_extra).exists()
 
 
-class TestS3StorageProvider(provider.S3StorageProvider):
+@mock_aws
+class TestS3StorageProvider:
+    Storage = provider.S3StorageProvider()
+    Storage.bucket_name = "my-bucket"
+    s3_client = boto3.resource("s3", region_name="us-east-1")
     file_content = b"test content"
-    file_bytesio = io.BytesIO(file_content)
     filename = "test.txt"
-    filename_extra = "test_extra.txt"
+    filename_extra = "test_exyta.txt"
     file_bytesio_empty = io.BytesIO()
-    bucket_name = "my-bucket"
 
     def test_upload_file(self, monkeypatch, tmp_path):
         upload_dir = mock_upload_dir(monkeypatch, tmp_path)
-        contents, file_path = self.upload_file(self.file_bytesio, self.filename)
+        # S3 checks
+        with pytest.raises(Exception):
+            self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
+        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
+        contents, s3_file_path = self.Storage.upload_file(
+            io.BytesIO(self.file_content), self.filename
+        )
+        object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
+        assert self.file_content == object.get()["Body"].read()
+        # local checks
         assert (upload_dir / self.filename).exists()
         assert (upload_dir / self.filename).read_bytes() == self.file_content
         assert contents == self.file_content
-        assert file_path == str(upload_dir / self.filename)
+        assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
         with pytest.raises(ValueError):
-            self.upload_file(self.file_bytesio_empty, self.filename)
+            self.Storage.upload_file(self.file_bytesio_empty, self.filename)
 
-    def test_get_file(self):
-        pass
+    def test_get_file(self, monkeypatch, tmp_path):
+        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
+        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
+        contents, s3_file_path = self.Storage.upload_file(
+            io.BytesIO(self.file_content), self.filename
+        )
+        file_path = self.Storage.get_file(s3_file_path)
+        assert file_path == str(upload_dir / self.filename)
+        assert (upload_dir / self.filename).exists()
 
-    def test_delete_file(self):
-        pass
+    def test_delete_file(self, monkeypatch, tmp_path):
+        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
+        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
+        contents, s3_file_path = self.Storage.upload_file(
+            io.BytesIO(self.file_content), self.filename
+        )
+        assert (upload_dir / self.filename).exists()
+        self.Storage.delete_file(s3_file_path)
+        assert not (upload_dir / self.filename).exists()
+        with pytest.raises(ClientError) as exc:
+            self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
+        error = exc.value.response["Error"]
+        assert error["Code"] == "404"
+        assert error["Message"] == "Not Found"
 
-    def test_delete_all_files(self):
-        pass
+    def test_delete_all_files(self, monkeypatch, tmp_path):
+        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
+        # create 2 files
+        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
+        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
+        object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
+        assert self.file_content == object.get()["Body"].read()
+        assert (upload_dir / self.filename).exists()
+        assert (upload_dir / self.filename).read_bytes() == self.file_content
+        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
+        object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
+        assert self.file_content == object.get()["Body"].read()
+        assert (upload_dir / self.filename).exists()
+        assert (upload_dir / self.filename).read_bytes() == self.file_content
+
+        self.Storage.delete_all_files()
+        assert not (upload_dir / self.filename).exists()
+        with pytest.raises(ClientError) as exc:
+            self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
+        error = exc.value.response["Error"]
+        assert error["Code"] == "404"
+        assert error["Message"] == "Not Found"
+        assert not (upload_dir / self.filename_extra).exists()
+        with pytest.raises(ClientError) as exc:
+            self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
+        error = exc.value.response["Error"]
+        assert error["Code"] == "404"
+        assert error["Message"] == "Not Found"
+
+        self.Storage.delete_all_files()
+        assert not (upload_dir / self.filename).exists()
+        assert not (upload_dir / self.filename_extra).exists()