# provider.py

"""Storage provider abstraction.

Dispatches file operations to either Amazon S3 (via boto3) or the local
filesystem, depending on the STORAGE_PROVIDER configuration value.
"""

import os

import boto3
from botocore.exceptions import ClientError

from open_webui.config import (
    STORAGE_PROVIDER,
    S3_ACCESS_KEY_ID,
    S3_SECRET_ACCESS_KEY,
    S3_BUCKET_NAME,
    S3_REGION_NAME,
    S3_ENDPOINT_URL,
    UPLOAD_DIR,
    AppConfig,
)


class StorageProvider:
    """File storage backed by either S3 or the local upload directory."""

    def __init__(self):
        self.storage_provider = None
        self.client = None
        self.bucket_name = None

        if STORAGE_PROVIDER == "s3":
            # S3 backend: configure a boto3 client from the app configuration.
            self.storage_provider = "s3"
            self.client = boto3.client(
                "s3",
                region_name=S3_REGION_NAME,
                endpoint_url=S3_ENDPOINT_URL,
                aws_access_key_id=S3_ACCESS_KEY_ID,
                aws_secret_access_key=S3_SECRET_ACCESS_KEY,
            )
            self.bucket_name = S3_BUCKET_NAME
        else:
            # Fall back to storing files under UPLOAD_DIR on the local disk.
            self.storage_provider = "local"

    def get_storage_provider(self):
        return self.storage_provider

    def upload_file(self, file, filename):
        """Store a file-like object under `filename` and return the filename."""
        if self.storage_provider == "s3":
            try:
                bucket = self.bucket_name
                self.client.upload_fileobj(file, bucket, filename)
                return filename
            except ClientError as e:
                raise RuntimeError(f"Error uploading file: {e}")
        else:
            file_path = os.path.join(UPLOAD_DIR, filename)
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, "wb") as f:
                f.write(file.read())
            return filename

    def list_files(self):
        if self.storage_provider == "s3":
            try:
                bucket = self.bucket_name
                response = self.client.list_objects_v2(Bucket=bucket)
                if "Contents" in response:
                    return [content["Key"] for content in response["Contents"]]
                return []
            except ClientError as e:
                raise RuntimeError(f"Error listing files: {e}")
        else:
            return [
                f
                for f in os.listdir(UPLOAD_DIR)
                if os.path.isfile(os.path.join(UPLOAD_DIR, f))
            ]

    def get_file(self, filename):
        """Return a local path to `filename`; S3 objects are downloaded to /tmp."""
        if self.storage_provider == "s3":
            try:
                bucket = self.bucket_name
                file_path = f"/tmp/{filename}"
                self.client.download_file(bucket, filename, file_path)
                return file_path
            except ClientError as e:
                raise RuntimeError(f"Error downloading file: {e}")
        else:
            file_path = os.path.join(UPLOAD_DIR, filename)
            if os.path.isfile(file_path):
                return file_path
            else:
                raise FileNotFoundError(f"File {filename} not found in local storage.")

    def delete_file(self, filename):
        if self.storage_provider == "s3":
            try:
                bucket = self.bucket_name
                self.client.delete_object(Bucket=bucket, Key=filename)
            except ClientError as e:
                raise RuntimeError(f"Error deleting file: {e}")
        else:
            file_path = os.path.join(UPLOAD_DIR, filename)
            if os.path.isfile(file_path):
                os.remove(file_path)
            else:
                raise FileNotFoundError(f"File {filename} not found in local storage.")

    def delete_all_files(self):
        """Delete every object in the S3 bucket or every file in UPLOAD_DIR."""
        if self.storage_provider == "s3":
            try:
                bucket = self.bucket_name
                response = self.client.list_objects_v2(Bucket=bucket)
                if "Contents" in response:
                    for content in response["Contents"]:
                        self.client.delete_object(Bucket=bucket, Key=content["Key"])
            except ClientError as e:
                raise RuntimeError(f"Error deleting all files: {e}")
        else:
            for filename in os.listdir(UPLOAD_DIR):
                file_path = os.path.join(UPLOAD_DIR, filename)
                if os.path.isfile(file_path):
                    os.remove(file_path)


# Module-level singleton used by the rest of the application.
Storage = StorageProvider()
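

# --- Illustrative usage (a minimal sketch, not part of the original module) --
# Exercises the module-level `Storage` singleton end to end. The in-memory
# payload and the "example.txt" key below are hypothetical; in S3 mode this
# performs real calls against the configured bucket.
if __name__ == "__main__":
    import io

    print("Active provider:", Storage.get_storage_provider())

    # Upload an in-memory file-like object under a hypothetical key.
    key = Storage.upload_file(io.BytesIO(b"hello storage"), "example.txt")
    print("Uploaded:", key)
    print("Files:", Storage.list_files())

    # Resolve to a local path (downloaded to /tmp when the provider is S3).
    print("Local path:", Storage.get_file(key))

    Storage.delete_file(key)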