Browse Source

use key_prefix in rest of S3StorageProvider

Patrick Deniso 2 months ago
parent
commit
7f82476926
1 changed files with 18 additions and 8 deletions
  1. 18 8
      backend/open_webui/storage/provider.py

+ 18 - 8
backend/open_webui/storage/provider.py

@@ -94,35 +94,36 @@ class S3StorageProvider(StorageProvider):
             aws_secret_access_key=S3_SECRET_ACCESS_KEY,
             aws_secret_access_key=S3_SECRET_ACCESS_KEY,
         )
         )
         self.bucket_name = S3_BUCKET_NAME
         self.bucket_name = S3_BUCKET_NAME
+        self.key_prefix = S3_KEY_PREFIX
 
 
     def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
     def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
         """Handles uploading of the file to S3 storage."""
         """Handles uploading of the file to S3 storage."""
         _, file_path = LocalStorageProvider.upload_file(file, filename)
         _, file_path = LocalStorageProvider.upload_file(file, filename)
         try:
         try:
-            s3_key = os.path.join(S3_KEY_PREFIX, filename)
+            s3_key = os.path.join(self.key_prefix, filename)
             self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
             self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
             return (
             return (
                 open(file_path, "rb").read(),
                 open(file_path, "rb").read(),
-                "s3://" + self.bucket_name + "/" + filename,
+                "s3://" + self.bucket_name + "/" + s3_key,
             )
             )
         except ClientError as e:
         except ClientError as e:
             raise RuntimeError(f"Error uploading file to S3: {e}")
             raise RuntimeError(f"Error uploading file to S3: {e}")
-
+    
     def get_file(self, file_path: str) -> str:
     def get_file(self, file_path: str) -> str:
         """Handles downloading of the file from S3 storage."""
         """Handles downloading of the file from S3 storage."""
         try:
         try:
-            bucket_name, key = file_path.split("//")[1].split("/")
-            local_file_path = f"{UPLOAD_DIR}/{key}"
-            self.s3_client.download_file(bucket_name, key, local_file_path)
+            s3_key = self._extract_s3_key(file_path)
+            local_file_path = self._get_local_file_path(s3_key)
+            self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
             return local_file_path
             return local_file_path
         except ClientError as e:
         except ClientError as e:
             raise RuntimeError(f"Error downloading file from S3: {e}")
             raise RuntimeError(f"Error downloading file from S3: {e}")
 
 
     def delete_file(self, file_path: str) -> None:
     def delete_file(self, file_path: str) -> None:
         """Handles deletion of the file from S3 storage."""
         """Handles deletion of the file from S3 storage."""
-        filename = file_path.split("/")[-1]
         try:
         try:
-            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
+            s3_key = self._extract_s3_key(file_path)
+            self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
         except ClientError as e:
         except ClientError as e:
             raise RuntimeError(f"Error deleting file from S3: {e}")
             raise RuntimeError(f"Error deleting file from S3: {e}")
 
 
@@ -135,6 +136,9 @@ class S3StorageProvider(StorageProvider):
             response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
             response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
             if "Contents" in response:
             if "Contents" in response:
                 for content in response["Contents"]:
                 for content in response["Contents"]:
+                    # Skip objects that were not uploaded from open-webui in the first place
+                    if not content["Key"].startswith(self.key_prefix): continue
+
                     self.s3_client.delete_object(
                     self.s3_client.delete_object(
                         Bucket=self.bucket_name, Key=content["Key"]
                         Bucket=self.bucket_name, Key=content["Key"]
                     )
                     )
@@ -144,6 +148,12 @@ class S3StorageProvider(StorageProvider):
         # Always delete from local storage
         # Always delete from local storage
         LocalStorageProvider.delete_all_files()
         LocalStorageProvider.delete_all_files()
 
 
+    # The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
+    def _extract_s3_key(self, full_file_path: str) -> str:
+        return ''.join(full_file_path.split("//")[1].split("/")[1:])
+    
+    def _get_local_file_path(self, s3_key: str) -> str:
+        return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"
 
 
 class GCSStorageProvider(StorageProvider):
 class GCSStorageProvider(StorageProvider):
     def __init__(self):
     def __init__(self):