123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
import os
import shutil

import boto3
from botocore.exceptions import ClientError

from open_webui.config import (
    STORAGE_PROVIDER,
    S3_ACCESS_KEY_ID,
    S3_SECRET_ACCESS_KEY,
    S3_BUCKET_NAME,
    S3_REGION_NAME,
    S3_ENDPOINT_URL,
    UPLOAD_DIR,
    AppConfig,
)
class StorageProvider:
    """File storage facade backed by an S3 bucket or a local directory.

    The backend is chosen once, at construction time: when the
    ``STORAGE_PROVIDER`` setting equals ``"s3"`` a boto3 S3 client is
    created and all operations target ``S3_BUCKET_NAME``; any other value
    selects the local backend, which works inside ``UPLOAD_DIR``.
    """

    def __init__(self):
        """Select the backend and, for S3, build the client."""
        self.storage_provider = None
        self.client = None
        self.bucket_name = None

        if STORAGE_PROVIDER == "s3":
            self.storage_provider = "s3"
            self.client = boto3.client(
                "s3",
                region_name=S3_REGION_NAME,
                endpoint_url=S3_ENDPOINT_URL,
                aws_access_key_id=S3_ACCESS_KEY_ID,
                aws_secret_access_key=S3_SECRET_ACCESS_KEY,
            )
            self.bucket_name = S3_BUCKET_NAME
        else:
            self.storage_provider = "local"

    def _s3_keys(self):
        """Return every object key in the bucket.

        A single ``list_objects_v2`` call returns a bounded number of keys,
        so a paginator is used to walk all result pages.
        """
        paginator = self.client.get_paginator("list_objects_v2")
        return [
            obj["Key"]
            for page in paginator.paginate(Bucket=self.bucket_name)
            # Pages of an empty bucket carry no "Contents" entry.
            for obj in page.get("Contents", [])
        ]

    def get_storage_provider(self):
        """Return the active backend name: ``"s3"`` or ``"local"``."""
        return self.storage_provider

    def upload_file(self, file, filename):
        """Store the binary file-like object *file* under *filename*.

        Returns:
            The *filename* that was passed in.

        Raises:
            RuntimeError: If the S3 upload fails with a ``ClientError``.
        """
        if self.storage_provider == "s3":
            try:
                self.client.upload_fileobj(file, self.bucket_name, filename)
            except ClientError as e:
                raise RuntimeError(f"Error uploading file: {e}") from e
            return filename

        file_path = os.path.join(UPLOAD_DIR, filename)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "wb") as f:
            # Stream in chunks so large uploads are never held fully in memory.
            shutil.copyfileobj(file, f)
        return filename

    def list_files(self):
        """Return the names of all stored files.

        For S3 these are all object keys in the bucket; for local storage
        these are the regular files directly inside ``UPLOAD_DIR``.

        Raises:
            RuntimeError: If listing the S3 bucket fails with a ``ClientError``.
        """
        if self.storage_provider == "s3":
            try:
                return self._s3_keys()
            except ClientError as e:
                raise RuntimeError(f"Error listing files: {e}") from e

        return [
            f
            for f in os.listdir(UPLOAD_DIR)
            if os.path.isfile(os.path.join(UPLOAD_DIR, f))
        ]

    def get_file(self, filename):
        """Return a local filesystem path for *filename*.

        For S3 the object is downloaded to ``/tmp/<filename>`` and that path
        is returned; for local storage the path inside ``UPLOAD_DIR`` is
        returned.

        Raises:
            RuntimeError: If the S3 download fails with a ``ClientError``.
            FileNotFoundError: If the file is missing from local storage.
        """
        if self.storage_provider == "s3":
            file_path = f"/tmp/{filename}"
            # Keys may contain "/" segments; the download needs the parent
            # directory to exist.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            try:
                self.client.download_file(self.bucket_name, filename, file_path)
            except ClientError as e:
                raise RuntimeError(f"Error downloading file: {e}") from e
            return file_path

        file_path = os.path.join(UPLOAD_DIR, filename)
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"File {filename} not found in local storage.")
        return file_path

    def delete_file(self, filename):
        """Delete *filename* from the active backend.

        Raises:
            RuntimeError: If the S3 deletion fails with a ``ClientError``.
            FileNotFoundError: If the file is missing from local storage.
        """
        if self.storage_provider == "s3":
            try:
                self.client.delete_object(Bucket=self.bucket_name, Key=filename)
            except ClientError as e:
                raise RuntimeError(f"Error deleting file: {e}") from e
            return

        file_path = os.path.join(UPLOAD_DIR, filename)
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"File {filename} not found in local storage.")
        os.remove(file_path)

    def delete_all_files(self):
        """Delete every stored file.

        For S3 every object in the bucket is deleted; for local storage only
        regular files directly inside ``UPLOAD_DIR`` are removed
        (subdirectories are left in place).

        Raises:
            RuntimeError: If listing or deleting on S3 fails with a
                ``ClientError``.
        """
        if self.storage_provider == "s3":
            try:
                for key in self._s3_keys():
                    self.client.delete_object(Bucket=self.bucket_name, Key=key)
            except ClientError as e:
                raise RuntimeError(f"Error deleting all files: {e}") from e
            return

        for filename in os.listdir(UPLOAD_DIR):
            file_path = os.path.join(UPLOAD_DIR, filename)
            if os.path.isfile(file_path):
                os.remove(file_path)
# Module-level shared instance; constructing it here selects the backend
# (S3 or local) from configuration when this module is imported.
Storage = StorageProvider()
|