瀏覽代碼

feat: s3 support

Timothy J. Baek 6 月之前
父節點
當前提交
7984980619
共有 2 個文件被更改,包括 180 次插入和 132 次刪除
  1. 47 46
      backend/open_webui/apps/webui/routers/files.py
  2. 133 86
      backend/open_webui/storage/provider.py

+ 47 - 46
backend/open_webui/apps/webui/routers/files.py

@@ -1,14 +1,19 @@
 import logging
 import logging
 import os
 import os
-import shutil
 import uuid
 import uuid
 from pathlib import Path
 from pathlib import Path
 from typing import Optional
 from typing import Optional
 from pydantic import BaseModel
 from pydantic import BaseModel
 import mimetypes
 import mimetypes
 
 
+from open_webui.storage.provider import Storage
 
 
-from open_webui.apps.webui.models.files import FileForm, FileModel, Files
+from open_webui.apps.webui.models.files import (
+    FileForm,
+    FileModel,
+    FileModelResponse,
+    Files,
+)
 from open_webui.apps.retrieval.main import process_file, ProcessFileForm
 from open_webui.apps.retrieval.main import process_file, ProcessFileForm
 
 
 from open_webui.config import UPLOAD_DIR
 from open_webui.config import UPLOAD_DIR
@@ -44,18 +49,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
         id = str(uuid.uuid4())
         id = str(uuid.uuid4())
         name = filename
         name = filename
         filename = f"{id}_{filename}"
         filename = f"{id}_{filename}"
-        file_path = f"{UPLOAD_DIR}/{filename}"
-
-        contents = file.file.read()
-        if len(contents) == 0:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.EMPTY_CONTENT,
-            )
-
-        with open(file_path, "wb") as f:
-            f.write(contents)
-            f.close()
+        contents, file_path = Storage.upload_file(file.file, filename)
 
 
         file = Files.insert_new_file(
         file = Files.insert_new_file(
             user.id,
             user.id,
@@ -101,7 +95,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
 ############################
 ############################
 
 
 
 
-@router.get("/", response_model=list[FileModel])
+@router.get("/", response_model=list[FileModelResponse])
 async def list_files(user=Depends(get_verified_user)):
 async def list_files(user=Depends(get_verified_user)):
     if user.role == "admin":
     if user.role == "admin":
         files = Files.get_files()
         files = Files.get_files()
@@ -118,27 +112,16 @@ async def list_files(user=Depends(get_verified_user)):
 @router.delete("/all")
 @router.delete("/all")
 async def delete_all_files(user=Depends(get_admin_user)):
 async def delete_all_files(user=Depends(get_admin_user)):
     result = Files.delete_all_files()
     result = Files.delete_all_files()
-
     if result:
     if result:
-        folder = f"{UPLOAD_DIR}"
         try:
         try:
-            # Check if the directory exists
-            if os.path.exists(folder):
-                # Iterate over all the files and directories in the specified directory
-                for filename in os.listdir(folder):
-                    file_path = os.path.join(folder, filename)
-                    try:
-                        if os.path.isfile(file_path) or os.path.islink(file_path):
-                            os.unlink(file_path)  # Remove the file or link
-                        elif os.path.isdir(file_path):
-                            shutil.rmtree(file_path)  # Remove the directory
-                    except Exception as e:
-                        print(f"Failed to delete {file_path}. Reason: {e}")
-            else:
-                print(f"The directory {folder} does not exist")
+            Storage.delete_all_files()
         except Exception as e:
         except Exception as e:
-            print(f"Failed to process the directory {folder}. Reason: {e}")
-
+            log.exception(e)
+            log.error(f"Error deleting files")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
+            )
         return {"message": "All files deleted successfully"}
         return {"message": "All files deleted successfully"}
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -222,21 +205,29 @@ async def update_file_data_content_by_id(
 @router.get("/{id}/content")
 @router.get("/{id}/content")
 async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
     file = Files.get_file_by_id(id)
-
     if file and (file.user_id == user.id or user.role == "admin"):
     if file and (file.user_id == user.id or user.role == "admin"):
-        file_path = Path(file.path)
-
-        # Check if the file already exists in the cache
-        if file_path.is_file():
-            print(f"file_path: {file_path}")
-            headers = {
-                "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
-            }
-            return FileResponse(file_path, headers=headers)
-        else:
+        try:
+            file_path = Storage.get_file(file.path)
+            file_path = Path(file_path)
+
+            # Check if the file already exists in the cache
+            if file_path.is_file():
+                print(f"file_path: {file_path}")
+                headers = {
+                    "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
+                }
+                return FileResponse(file_path, headers=headers)
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_404_NOT_FOUND,
+                    detail=ERROR_MESSAGES.NOT_FOUND,
+                )
+        except Exception as e:
+            log.exception(e)
+            log.error(f"Error getting file content")
             raise HTTPException(
             raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND,
-                detail=ERROR_MESSAGES.NOT_FOUND,
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
             )
             )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -252,6 +243,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
     if file and (file.user_id == user.id or user.role == "admin"):
     if file and (file.user_id == user.id or user.role == "admin"):
         file_path = file.path
         file_path = file.path
         if file_path:
         if file_path:
+            file_path = Storage.get_file(file_path)
             file_path = Path(file_path)
             file_path = Path(file_path)
 
 
             # Check if the file already exists in the cache
             # Check if the file already exists in the cache
@@ -298,6 +290,15 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
     if file and (file.user_id == user.id or user.role == "admin"):
     if file and (file.user_id == user.id or user.role == "admin"):
         result = Files.delete_file_by_id(id)
         result = Files.delete_file_by_id(id)
         if result:
         if result:
+            try:
+                Storage.delete_file(file.filename)
+            except Exception as e:
+                log.exception(e)
+                log.error(f"Error deleting files")
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
+                )
             return {"message": "File deleted successfully"}
             return {"message": "File deleted successfully"}
         else:
         else:
             raise HTTPException(
             raise HTTPException(

+ 133 - 86
backend/open_webui/storage/provider.py

@@ -1,6 +1,12 @@
 import os
 import os
 import boto3
 import boto3
 from botocore.exceptions import ClientError
 from botocore.exceptions import ClientError
+import shutil
+
+
+from typing import BinaryIO, Tuple, Optional, Union
+
+from open_webui.constants import ERROR_MESSAGES
 from open_webui.config import (
 from open_webui.config import (
     STORAGE_PROVIDER,
     STORAGE_PROVIDER,
     S3_ACCESS_KEY_ID,
     S3_ACCESS_KEY_ID,
@@ -9,109 +15,150 @@ from open_webui.config import (
     S3_REGION_NAME,
     S3_REGION_NAME,
     S3_ENDPOINT_URL,
     S3_ENDPOINT_URL,
     UPLOAD_DIR,
     UPLOAD_DIR,
-    AppConfig,
 )
 )
 
 
 
 
+import boto3
+from boto3.s3 import S3Client
+from botocore.exceptions import ClientError
+from typing import BinaryIO, Tuple, Optional
+
+
 class StorageProvider:
 class StorageProvider:
-    def __init__(self):
-        self.storage_provider = None
-        self.client = None
-        self.bucket_name = None
-
-        if STORAGE_PROVIDER == "s3":
-            self.storage_provider = "s3"
-            self.client = boto3.client(
-                "s3",
-                region_name=S3_REGION_NAME,
-                endpoint_url=S3_ENDPOINT_URL,
-                aws_access_key_id=S3_ACCESS_KEY_ID,
-                aws_secret_access_key=S3_SECRET_ACCESS_KEY,
-            )
-            self.bucket_name = S3_BUCKET_NAME
-        else:
-            self.storage_provider = "local"
+    def __init__(self, provider: Optional[str] = None):
+        self.storage_provider: str = provider or STORAGE_PROVIDER
 
 
-    def get_storage_provider(self):
-        return self.storage_provider
+        self.s3_client = None
+        self.s3_bucket_name: Optional[str] = None
 
 
-    def upload_file(self, file, filename):
         if self.storage_provider == "s3":
         if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                self.client.upload_fileobj(file, bucket, filename)
-                return filename
-            except ClientError as e:
-                raise RuntimeError(f"Error uploading file: {e}")
+            self._initialize_s3()
+
+    def _initialize_s3(self) -> None:
+        """Initializes the S3 client and bucket name if using S3 storage."""
+        self.s3_client = boto3.client(
+            "s3",
+            region_name=S3_REGION_NAME,
+            endpoint_url=S3_ENDPOINT_URL,
+            aws_access_key_id=S3_ACCESS_KEY_ID,
+            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
+        )
+        self.bucket_name = S3_BUCKET_NAME
+
+    def _upload_to_s3(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        """Handles uploading of the file to S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            self.s3_client.upload_fileobj(file, self.bucket_name, filename)
+            return file.read(), f"s3://{self.bucket_name}/{filename}"
+        except ClientError as e:
+            raise RuntimeError(f"Error uploading file to S3: {e}")
+
+    def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
+        """Handles uploading of the file to local storage."""
+        file_path = f"{UPLOAD_DIR}/{filename}"
+        with open(file_path, "wb") as f:
+            f.write(contents)
+        return contents, file_path
+
+    def _get_file_from_s3(self, file_path: str) -> str:
+        """Handles downloading of the file from S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            bucket_name, key = file_path.split("//")[1].split("/")
+            local_file_path = f"{UPLOAD_DIR}/{key}"
+            self.s3_client.download_file(bucket_name, key, local_file_path)
+            return local_file_path
+        except ClientError as e:
+            raise RuntimeError(f"Error downloading file from S3: {e}")
+
+    def _get_file_from_local(self, file_path: str) -> str:
+        """Handles downloading of the file from local storage."""
+        return file_path
+
+    def _delete_from_s3(self, filename: str) -> None:
+        """Handles deletion of the file from S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
+        except ClientError as e:
+            raise RuntimeError(f"Error deleting file from S3: {e}")
+
+    def _delete_from_local(self, filename: str) -> None:
+        """Handles deletion of the file from local storage."""
+        file_path = f"{UPLOAD_DIR}/{filename}"
+        if os.path.isfile(file_path):
+            os.remove(file_path)
         else:
         else:
-            file_path = os.path.join(UPLOAD_DIR, filename)
-            os.makedirs(os.path.dirname(file_path), exist_ok=True)
-            with open(file_path, "wb") as f:
-                f.write(file.read())
-            return filename
+            raise FileNotFoundError(f"File {filename} not found in local storage.")
 
 
-    def list_files(self):
-        if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                response = self.client.list_objects_v2(Bucket=bucket)
-                if "Contents" in response:
-                    return [content["Key"] for content in response["Contents"]]
-                return []
-            except ClientError as e:
-                raise RuntimeError(f"Error listing files: {e}")
+    def _delete_all_from_s3(self) -> None:
+        """Handles deletion of all files from S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
+            if "Contents" in response:
+                for content in response["Contents"]:
+                    self.s3_client.delete_object(
+                        Bucket=self.bucket_name, Key=content["Key"]
+                    )
+        except ClientError as e:
+            raise RuntimeError(f"Error deleting all files from S3: {e}")
+
+    def _delete_all_from_local(self) -> None:
+        """Handles deletion of all files from local storage."""
+        if os.path.exists(UPLOAD_DIR):
+            for filename in os.listdir(UPLOAD_DIR):
+                file_path = os.path.join(UPLOAD_DIR, filename)
+                try:
+                    if os.path.isfile(file_path) or os.path.islink(file_path):
+                        os.unlink(file_path)  # Remove the file or link
+                    elif os.path.isdir(file_path):
+                        shutil.rmtree(file_path)  # Remove the directory
+                except Exception as e:
+                    print(f"Failed to delete {file_path}. Reason: {e}")
         else:
         else:
-            return [
-                f
-                for f in os.listdir(UPLOAD_DIR)
-                if os.path.isfile(os.path.join(UPLOAD_DIR, f))
-            ]
+            raise FileNotFoundError(
+                f"Directory {UPLOAD_DIR} not found in local storage."
+            )
+
+    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        """Uploads a file either to S3 or the local file system."""
+        contents = file.read()
+        if not contents:
+            raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
 
 
-    def get_file(self, filename):
         if self.storage_provider == "s3":
         if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                file_path = f"/tmp/{filename}"
-                self.client.download_file(bucket, filename, file_path)
-                return file_path
-            except ClientError as e:
-                raise RuntimeError(f"Error downloading file: {e}")
-        else:
-            file_path = os.path.join(UPLOAD_DIR, filename)
-            if os.path.isfile(file_path):
-                return file_path
-            else:
-                raise FileNotFoundError(f"File {filename} not found in local storage.")
+            return self._upload_to_s3(file, filename)
+        return self._upload_to_local(contents, filename)
 
 
-    def delete_file(self, filename):
+    def get_file(self, file_path: str) -> str:
+        """Downloads a file either from S3 or the local file system and returns the file path."""
         if self.storage_provider == "s3":
         if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                self.client.delete_object(Bucket=bucket, Key=filename)
-            except ClientError as e:
-                raise RuntimeError(f"Error deleting file: {e}")
+            return self._get_file_from_s3(file_path)
+        return self._get_file_from_local(file_path)
+
+    def delete_file(self, filename: str) -> None:
+        """Deletes a file either from S3 or the local file system."""
+        if self.storage_provider == "s3":
+            self._delete_from_s3(filename)
         else:
         else:
-            file_path = os.path.join(UPLOAD_DIR, filename)
-            if os.path.isfile(file_path):
-                os.remove(file_path)
-            else:
-                raise FileNotFoundError(f"File {filename} not found in local storage.")
+            self._delete_from_local(filename)
 
 
-    def delete_all_files(self):
+    def delete_all_files(self) -> None:
+        """Deletes all files from the storage."""
         if self.storage_provider == "s3":
         if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                response = self.client.list_objects_v2(Bucket=bucket)
-                if "Contents" in response:
-                    for content in response["Contents"]:
-                        self.client.delete_object(Bucket=bucket, Key=content["Key"])
-            except ClientError as e:
-                raise RuntimeError(f"Error deleting all files: {e}")
+            self._delete_all_from_s3()
         else:
         else:
-            for filename in os.listdir(UPLOAD_DIR):
-                file_path = os.path.join(UPLOAD_DIR, filename)
-                if os.path.isfile(file_path):
-                    os.remove(file_path)
+            self._delete_all_from_local()
 
 
 
 
-Storage = StorageProvider()
+Storage = StorageProvider(provider=STORAGE_PROVIDER)