浏览代码

chore: format backend

Timothy Jaeryang Baek 3 月之前
父节点
当前提交
8d3c73aed5

+ 3 - 1
backend/open_webui/config.py

@@ -663,7 +663,9 @@ S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
 S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
 S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
 
 
 GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
 GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
-GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_JSON", None)
+GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get(
+    "GOOGLE_APPLICATION_CREDENTIALS_JSON", None
+)
 
 
 ####################################
 ####################################
 # File Upload DIR
 # File Upload DIR

+ 14 - 9
backend/open_webui/storage/provider.py

@@ -142,19 +142,22 @@ class S3StorageProvider(StorageProvider):
         # Always delete from local storage
         # Always delete from local storage
         LocalStorageProvider.delete_all_files()
         LocalStorageProvider.delete_all_files()
 
 
+
 class GCSStorageProvider(StorageProvider):
 class GCSStorageProvider(StorageProvider):
     def __init__(self):
     def __init__(self):
         self.bucket_name = GCS_BUCKET_NAME
         self.bucket_name = GCS_BUCKET_NAME
-    
+
         if GOOGLE_APPLICATION_CREDENTIALS_JSON:
         if GOOGLE_APPLICATION_CREDENTIALS_JSON:
-            self.gcs_client = storage.Client.from_service_account_info(info=json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON))
+            self.gcs_client = storage.Client.from_service_account_info(
+                info=json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
+            )
         else:
         else:
             # if no credentials json is provided, credentials will be picked up from the environment
             # if no credentials json is provided, credentials will be picked up from the environment
             # if running on local environment, credentials would be user credentials
             # if running on local environment, credentials would be user credentials
             # if running on a Compute Engine instance, credentials would be from Google Metadata server
             # if running on a Compute Engine instance, credentials would be from Google Metadata server
             self.gcs_client = storage.Client()
             self.gcs_client = storage.Client()
         self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME)
         self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME)
-    
+
     def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
     def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
         """Handles uploading of the file to GCS storage."""
         """Handles uploading of the file to GCS storage."""
         contents, file_path = LocalStorageProvider.upload_file(file, filename)
         contents, file_path = LocalStorageProvider.upload_file(file, filename)
@@ -165,19 +168,19 @@ class GCSStorageProvider(StorageProvider):
         except GoogleCloudError as e:
         except GoogleCloudError as e:
             raise RuntimeError(f"Error uploading file to GCS: {e}")
             raise RuntimeError(f"Error uploading file to GCS: {e}")
 
 
-    def get_file(self, file_path:str) -> str:
+    def get_file(self, file_path: str) -> str:
         """Handles downloading of the file from GCS storage."""
         """Handles downloading of the file from GCS storage."""
         try:
         try:
             filename = file_path.removeprefix("gs://").split("/")[1]
             filename = file_path.removeprefix("gs://").split("/")[1]
-            local_file_path = f"{UPLOAD_DIR}/{filename}"            
+            local_file_path = f"{UPLOAD_DIR}/{filename}"
             blob = self.bucket.get_blob(filename)
             blob = self.bucket.get_blob(filename)
             blob.download_to_filename(local_file_path)
             blob.download_to_filename(local_file_path)
 
 
             return local_file_path
             return local_file_path
         except NotFound as e:
         except NotFound as e:
             raise RuntimeError(f"Error downloading file from GCS: {e}")
             raise RuntimeError(f"Error downloading file from GCS: {e}")
-    
-    def delete_file(self, file_path:str) -> None:
+
+    def delete_file(self, file_path: str) -> None:
         """Handles deletion of the file from GCS storage."""
         """Handles deletion of the file from GCS storage."""
         try:
         try:
             filename = file_path.removeprefix("gs://").split("/")[1]
             filename = file_path.removeprefix("gs://").split("/")[1]
@@ -185,7 +188,7 @@ class GCSStorageProvider(StorageProvider):
             blob.delete()
             blob.delete()
         except NotFound as e:
         except NotFound as e:
             raise RuntimeError(f"Error deleting file from GCS: {e}")
             raise RuntimeError(f"Error deleting file from GCS: {e}")
-        
+
         # Always delete from local storage
         # Always delete from local storage
         LocalStorageProvider.delete_file(file_path)
         LocalStorageProvider.delete_file(file_path)
 
 
@@ -199,10 +202,11 @@ class GCSStorageProvider(StorageProvider):
 
 
         except NotFound as e:
         except NotFound as e:
             raise RuntimeError(f"Error deleting all files from GCS: {e}")
             raise RuntimeError(f"Error deleting all files from GCS: {e}")
-        
+
         # Always delete from local storage
         # Always delete from local storage
         LocalStorageProvider.delete_all_files()
         LocalStorageProvider.delete_all_files()
 
 
+
 def get_storage_provider(storage_provider: str):
 def get_storage_provider(storage_provider: str):
     if storage_provider == "local":
     if storage_provider == "local":
         Storage = LocalStorageProvider()
         Storage = LocalStorageProvider()
@@ -214,4 +218,5 @@ def get_storage_provider(storage_provider: str):
         raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
         raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
     return Storage
     return Storage
 
 
+
 Storage = get_storage_provider(STORAGE_PROVIDER)
 Storage = get_storage_provider(STORAGE_PROVIDER)

+ 8 - 5
backend/open_webui/test/apps/webui/storage/test_provider.py

@@ -104,7 +104,6 @@ class TestS3StorageProvider:
         self.file_bytesio_empty = io.BytesIO()
         self.file_bytesio_empty = io.BytesIO()
         super().__init__()
         super().__init__()
 
 
-
     def test_upload_file(self, monkeypatch, tmp_path):
     def test_upload_file(self, monkeypatch, tmp_path):
         upload_dir = mock_upload_dir(monkeypatch, tmp_path)
         upload_dir = mock_upload_dir(monkeypatch, tmp_path)
         # S3 checks
         # S3 checks
@@ -182,6 +181,7 @@ class TestS3StorageProvider:
         assert not (upload_dir / self.filename).exists()
         assert not (upload_dir / self.filename).exists()
         assert not (upload_dir / self.filename_extra).exists()
         assert not (upload_dir / self.filename_extra).exists()
 
 
+
 class TestGCSStorageProvider:
 class TestGCSStorageProvider:
     Storage = provider.GCSStorageProvider()
     Storage = provider.GCSStorageProvider()
     Storage.bucket_name = "my-bucket"
     Storage.bucket_name = "my-bucket"
@@ -202,15 +202,15 @@ class TestGCSStorageProvider:
         bucket = gcs_client.bucket(self.Storage.bucket_name)
         bucket = gcs_client.bucket(self.Storage.bucket_name)
         bucket.create()
         bucket.create()
         self.Storage.gcs_client, self.Storage.bucket = gcs_client, bucket
         self.Storage.gcs_client, self.Storage.bucket = gcs_client, bucket
-        yield 
+        yield
         bucket.delete(force=True)
         bucket.delete(force=True)
         server.stop()
         server.stop()
-    
+
     def test_upload_file(self, monkeypatch, tmp_path, setup):
     def test_upload_file(self, monkeypatch, tmp_path, setup):
         upload_dir = mock_upload_dir(monkeypatch, tmp_path)
         upload_dir = mock_upload_dir(monkeypatch, tmp_path)
         # catch error if bucket does not exist
         # catch error if bucket does not exist
         with pytest.raises(Exception):
         with pytest.raises(Exception):
-            self.Storage.bucket = monkeypatch(self.Storage, "bucket", None)  
+            self.Storage.bucket = monkeypatch(self.Storage, "bucket", None)
             self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
             self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
         contents, gcs_file_path = self.Storage.upload_file(
         contents, gcs_file_path = self.Storage.upload_file(
             io.BytesIO(self.file_content), self.filename
             io.BytesIO(self.file_content), self.filename
@@ -261,7 +261,10 @@ class TestGCSStorageProvider:
         object = self.Storage.bucket.get_blob(self.filename_extra)
         object = self.Storage.bucket.get_blob(self.filename_extra)
         assert (upload_dir / self.filename_extra).exists()
         assert (upload_dir / self.filename_extra).exists()
         assert (upload_dir / self.filename_extra).read_bytes() == self.file_content
         assert (upload_dir / self.filename_extra).read_bytes() == self.file_content
-        assert self.Storage.bucket.get_blob(self.filename_extra).name == self.filename_extra
+        assert (
+            self.Storage.bucket.get_blob(self.filename_extra).name
+            == self.filename_extra
+        )
         assert self.file_content == object.download_as_bytes()
         assert self.file_content == object.download_as_bytes()
 
 
         self.Storage.delete_all_files()
         self.Storage.delete_all_files()