
feat: s3 support

Timothy J. Baek 6 months ago
parent
commit
7984980619

+ 47 - 46
backend/open_webui/apps/webui/routers/files.py

@@ -1,14 +1,19 @@
 import logging
 import os
-import shutil
 import uuid
 from pathlib import Path
 from typing import Optional
 from pydantic import BaseModel
 import mimetypes
 
+from open_webui.storage.provider import Storage
 
-from open_webui.apps.webui.models.files import FileForm, FileModel, Files
+from open_webui.apps.webui.models.files import (
+    FileForm,
+    FileModel,
+    FileModelResponse,
+    Files,
+)
 from open_webui.apps.retrieval.main import process_file, ProcessFileForm
 
 from open_webui.config import UPLOAD_DIR
@@ -44,18 +49,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
         id = str(uuid.uuid4())
         name = filename
         filename = f"{id}_{filename}"
-        file_path = f"{UPLOAD_DIR}/{filename}"
-
-        contents = file.file.read()
-        if len(contents) == 0:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.EMPTY_CONTENT,
-            )
-
-        with open(file_path, "wb") as f:
-            f.write(contents)
-            f.close()
+        contents, file_path = Storage.upload_file(file.file, filename)
 
         file = Files.insert_new_file(
             user.id,
@@ -101,7 +95,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
 ############################
 
 
-@router.get("/", response_model=list[FileModel])
+@router.get("/", response_model=list[FileModelResponse])
 async def list_files(user=Depends(get_verified_user)):
     if user.role == "admin":
         files = Files.get_files()
@@ -118,27 +112,16 @@ async def list_files(user=Depends(get_verified_user)):
 @router.delete("/all")
 async def delete_all_files(user=Depends(get_admin_user)):
     result = Files.delete_all_files()
-
     if result:
-        folder = f"{UPLOAD_DIR}"
         try:
-            # Check if the directory exists
-            if os.path.exists(folder):
-                # Iterate over all the files and directories in the specified directory
-                for filename in os.listdir(folder):
-                    file_path = os.path.join(folder, filename)
-                    try:
-                        if os.path.isfile(file_path) or os.path.islink(file_path):
-                            os.unlink(file_path)  # Remove the file or link
-                        elif os.path.isdir(file_path):
-                            shutil.rmtree(file_path)  # Remove the directory
-                    except Exception as e:
-                        print(f"Failed to delete {file_path}. Reason: {e}")
-            else:
-                print(f"The directory {folder} does not exist")
+            Storage.delete_all_files()
         except Exception as e:
-            print(f"Failed to process the directory {folder}. Reason: {e}")
-
+            log.exception(e)
+            log.error(f"Error deleting files")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
+            )
         return {"message": "All files deleted successfully"}
     else:
         raise HTTPException(
@@ -222,21 +205,29 @@ async def update_file_data_content_by_id(
 @router.get("/{id}/content")
 async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
-
     if file and (file.user_id == user.id or user.role == "admin"):
-        file_path = Path(file.path)
-
-        # Check if the file already exists in the cache
-        if file_path.is_file():
-            print(f"file_path: {file_path}")
-            headers = {
-                "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
-            }
-            return FileResponse(file_path, headers=headers)
-        else:
+        try:
+            file_path = Storage.get_file(file.path)
+            file_path = Path(file_path)
+
+            # Check if the file already exists in the cache
+            if file_path.is_file():
+                print(f"file_path: {file_path}")
+                headers = {
+                    "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
+                }
+                return FileResponse(file_path, headers=headers)
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_404_NOT_FOUND,
+                    detail=ERROR_MESSAGES.NOT_FOUND,
+                )
+        except Exception as e:
+            log.exception(e)
+            log.error(f"Error getting file content")
             raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND,
-                detail=ERROR_MESSAGES.NOT_FOUND,
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
             )
     else:
         raise HTTPException(
@@ -252,6 +243,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
     if file and (file.user_id == user.id or user.role == "admin"):
         file_path = file.path
         if file_path:
+            file_path = Storage.get_file(file_path)
             file_path = Path(file_path)
 
             # Check if the file already exists in the cache
@@ -298,6 +290,15 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
     if file and (file.user_id == user.id or user.role == "admin"):
         result = Files.delete_file_by_id(id)
         if result:
+            try:
+                Storage.delete_file(file.filename)
+            except Exception as e:
+                log.exception(e)
+                log.error("Error deleting file")
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT("Error deleting file"),
+                )
             return {"message": "File deleted successfully"}
         else:
             raise HTTPException(

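For reference, the router-side upload path now reduces to the pattern below. This is a minimal sketch rather than the full endpoint: the helper name store_upload is hypothetical, the FastAPI UploadFile handling is simplified, and the Files.insert_new_file / process_file bookkeeping that the real handler performs is omitted.

import uuid

from fastapi import HTTPException, UploadFile, status

from open_webui.storage.provider import Storage


def store_upload(file: UploadFile) -> str:
    # Hypothetical helper illustrating the new flow: all filesystem/S3 details
    # are delegated to the Storage facade introduced in this commit.
    id = str(uuid.uuid4())
    filename = f"{id}_{file.filename}"

    try:
        # Storage decides whether the bytes land under UPLOAD_DIR or in the S3
        # bucket, and returns the contents plus a provider-specific path
        # (e.g. "s3://<bucket>/<key>" or "<UPLOAD_DIR>/<filename>").
        contents, file_path = Storage.upload_file(file.file, filename)
    except ValueError as e:
        # Empty uploads are rejected by the provider (ERROR_MESSAGES.EMPTY_CONTENT).
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
        )

    return file_path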
+ 133 - 86
backend/open_webui/storage/provider.py

@@ -1,6 +1,12 @@
 import os
 import boto3
 from botocore.exceptions import ClientError
+import shutil
+
+
+from typing import BinaryIO, Tuple, Optional, Union
+
+from open_webui.constants import ERROR_MESSAGES
 from open_webui.config import (
     STORAGE_PROVIDER,
     S3_ACCESS_KEY_ID,
@@ -9,109 +15,150 @@ from open_webui.config import (
     S3_REGION_NAME,
     S3_ENDPOINT_URL,
     UPLOAD_DIR,
-    AppConfig,
 )
 
 
 class StorageProvider:
-    def __init__(self):
-        self.storage_provider = None
-        self.client = None
-        self.bucket_name = None
-
-        if STORAGE_PROVIDER == "s3":
-            self.storage_provider = "s3"
-            self.client = boto3.client(
-                "s3",
-                region_name=S3_REGION_NAME,
-                endpoint_url=S3_ENDPOINT_URL,
-                aws_access_key_id=S3_ACCESS_KEY_ID,
-                aws_secret_access_key=S3_SECRET_ACCESS_KEY,
-            )
-            self.bucket_name = S3_BUCKET_NAME
-        else:
-            self.storage_provider = "local"
+    def __init__(self, provider: Optional[str] = None):
+        self.storage_provider: str = provider or STORAGE_PROVIDER
 
-    def get_storage_provider(self):
-        return self.storage_provider
+        self.s3_client = None
+        self.bucket_name: Optional[str] = None
 
-    def upload_file(self, file, filename):
         if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                self.client.upload_fileobj(file, bucket, filename)
-                return filename
-            except ClientError as e:
-                raise RuntimeError(f"Error uploading file: {e}")
+            self._initialize_s3()
+
+    def _initialize_s3(self) -> None:
+        """Initializes the S3 client and bucket name if using S3 storage."""
+        self.s3_client = boto3.client(
+            "s3",
+            region_name=S3_REGION_NAME,
+            endpoint_url=S3_ENDPOINT_URL,
+            aws_access_key_id=S3_ACCESS_KEY_ID,
+            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
+        )
+        self.bucket_name = S3_BUCKET_NAME
+
+    def _upload_to_s3(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
+        """Handles uploading of the file contents to S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            # The caller has already read the upload into memory, so push the
+            # raw bytes directly instead of re-reading an exhausted stream.
+            self.s3_client.put_object(
+                Bucket=self.bucket_name, Key=filename, Body=contents
+            )
+            return contents, f"s3://{self.bucket_name}/{filename}"
+        except ClientError as e:
+            raise RuntimeError(f"Error uploading file to S3: {e}")
+
+    def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
+        """Handles uploading of the file to local storage."""
+        file_path = f"{UPLOAD_DIR}/{filename}"
+        with open(file_path, "wb") as f:
+            f.write(contents)
+        return contents, file_path
+
+    def _get_file_from_s3(self, file_path: str) -> str:
+        """Handles downloading of the file from S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            bucket_name, key = file_path.split("//")[1].split("/", 1)
+            local_file_path = f"{UPLOAD_DIR}/{key}"
+            self.s3_client.download_file(bucket_name, key, local_file_path)
+            return local_file_path
+        except ClientError as e:
+            raise RuntimeError(f"Error downloading file from S3: {e}")
+
+    def _get_file_from_local(self, file_path: str) -> str:
+        """Local files are already on disk; returns the stored path unchanged."""
+        return file_path
+
+    def _delete_from_s3(self, filename: str) -> None:
+        """Handles deletion of the file from S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
+        except ClientError as e:
+            raise RuntimeError(f"Error deleting file from S3: {e}")
+
+    def _delete_from_local(self, filename: str) -> None:
+        """Handles deletion of the file from local storage."""
+        file_path = f"{UPLOAD_DIR}/{filename}"
+        if os.path.isfile(file_path):
+            os.remove(file_path)
         else:
-            file_path = os.path.join(UPLOAD_DIR, filename)
-            os.makedirs(os.path.dirname(file_path), exist_ok=True)
-            with open(file_path, "wb") as f:
-                f.write(file.read())
-            return filename
+            raise FileNotFoundError(f"File {filename} not found in local storage.")
 
-    def list_files(self):
-        if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                response = self.client.list_objects_v2(Bucket=bucket)
-                if "Contents" in response:
-                    return [content["Key"] for content in response["Contents"]]
-                return []
-            except ClientError as e:
-                raise RuntimeError(f"Error listing files: {e}")
+    def _delete_all_from_s3(self) -> None:
+        """Handles deletion of all files from S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
+            if "Contents" in response:
+                for content in response["Contents"]:
+                    self.s3_client.delete_object(
+                        Bucket=self.bucket_name, Key=content["Key"]
+                    )
+        except ClientError as e:
+            raise RuntimeError(f"Error deleting all files from S3: {e}")
+
+    def _delete_all_from_local(self) -> None:
+        """Handles deletion of all files from local storage."""
+        if os.path.exists(UPLOAD_DIR):
+            for filename in os.listdir(UPLOAD_DIR):
+                file_path = os.path.join(UPLOAD_DIR, filename)
+                try:
+                    if os.path.isfile(file_path) or os.path.islink(file_path):
+                        os.unlink(file_path)  # Remove the file or link
+                    elif os.path.isdir(file_path):
+                        shutil.rmtree(file_path)  # Remove the directory
+                except Exception as e:
+                    print(f"Failed to delete {file_path}. Reason: {e}")
         else:
-            return [
-                f
-                for f in os.listdir(UPLOAD_DIR)
-                if os.path.isfile(os.path.join(UPLOAD_DIR, f))
-            ]
+            raise FileNotFoundError(
+                f"Directory {UPLOAD_DIR} not found in local storage."
+            )
+
+    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        """Uploads a file either to S3 or the local file system."""
+        contents = file.read()
+        if not contents:
+            raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
 
-    def get_file(self, filename):
         if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                file_path = f"/tmp/{filename}"
-                self.client.download_file(bucket, filename, file_path)
-                return file_path
-            except ClientError as e:
-                raise RuntimeError(f"Error downloading file: {e}")
-        else:
-            file_path = os.path.join(UPLOAD_DIR, filename)
-            if os.path.isfile(file_path):
-                return file_path
-            else:
-                raise FileNotFoundError(f"File {filename} not found in local storage.")
+            return self._upload_to_s3(contents, filename)
+        return self._upload_to_local(contents, filename)
 
-    def delete_file(self, filename):
+    def get_file(self, file_path: str) -> str:
+        """Downloads a file either from S3 or the local file system and returns the file path."""
         if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                self.client.delete_object(Bucket=bucket, Key=filename)
-            except ClientError as e:
-                raise RuntimeError(f"Error deleting file: {e}")
+            return self._get_file_from_s3(file_path)
+        return self._get_file_from_local(file_path)
+
+    def delete_file(self, filename: str) -> None:
+        """Deletes a file either from S3 or the local file system."""
+        if self.storage_provider == "s3":
+            self._delete_from_s3(filename)
         else:
-            file_path = os.path.join(UPLOAD_DIR, filename)
-            if os.path.isfile(file_path):
-                os.remove(file_path)
-            else:
-                raise FileNotFoundError(f"File {filename} not found in local storage.")
+            self._delete_from_local(filename)
 
-    def delete_all_files(self):
+    def delete_all_files(self) -> None:
+        """Deletes all files from the storage."""
         if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                response = self.client.list_objects_v2(Bucket=bucket)
-                if "Contents" in response:
-                    for content in response["Contents"]:
-                        self.client.delete_object(Bucket=bucket, Key=content["Key"])
-            except ClientError as e:
-                raise RuntimeError(f"Error deleting all files: {e}")
+            self._delete_all_from_s3()
         else:
-            for filename in os.listdir(UPLOAD_DIR):
-                file_path = os.path.join(UPLOAD_DIR, filename)
-                if os.path.isfile(file_path):
-                    os.remove(file_path)
+            self._delete_all_from_local()
 
 
-Storage = StorageProvider()
+Storage = StorageProvider(provider=STORAGE_PROVIDER)
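
Taken together, the provider is selected once at import time from STORAGE_PROVIDER, and every caller goes through the module-level Storage singleton. The snippet below is a usage sketch under the assumption that the STORAGE_PROVIDER and S3_* values in open_webui.config are populated from configuration (for example, environment variables) before this module is imported; the bucket, key, and file names are placeholders.

import io

from open_webui.storage.provider import Storage

# Upload: returns the raw bytes plus a provider-specific path
# ("s3://<bucket>/<key>" with the S3 provider, "<UPLOAD_DIR>/<filename>" locally).
contents, path = Storage.upload_file(io.BytesIO(b"hello"), "example.txt")

# Download: with S3 the object is copied into UPLOAD_DIR and that local path is
# returned; with local storage the stored path is returned unchanged.
local_path = Storage.get_file(path)

# Delete a single object by key/filename, or remove everything the provider manages.
Storage.delete_file("example.txt")
Storage.delete_all_files()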