Browse Source

test_upload_file working, added gcp-storage-emulator, updated gcs client instantiation

kahghi 3 months ago
parent
commit
4b56c15a3f

+ 3 - 2
backend/open_webui/storage/provider.py

@@ -144,12 +144,13 @@ class S3StorageProvider(StorageProvider):
 
 class GCSStorageProvider(StorageProvider):
     def __init__(self):
+        self.bucket_name = GCS_BUCKET_NAME
+    
         if GCS_BUCKET_NAME and GOOGLE_APPLICATION_CREDENTIALS_JSON:
             self.gcs_client = storage.Client.from_service_account_info(info=json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON))
-        if GCS_BUCKET_NAME and not GOOGLE_APPLICATION_CREDENTIALS_JSON:
+        else:
             # defaults to environment, be it GCE VM or user credentials
             self.gcs_client = storage.Client()
-        self.bucket_name = GCS_BUCKET_NAME
         self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME)
     
     def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:

+ 109 - 1
backend/open_webui/test/apps/webui/storage/test_provider.py

@@ -1,10 +1,12 @@
 import io
-
+import os
 import boto3
 import pytest
 from botocore.exceptions import ClientError
 from moto import mock_aws
 from open_webui.storage import provider
+from gcp_storage_emulator.server import create_server
+from google.cloud import storage
 
 
 def mock_upload_dir(monkeypatch, tmp_path):
@@ -19,6 +21,7 @@ def test_imports():
     provider.StorageProvider
     provider.LocalStorageProvider
     provider.S3StorageProvider
+    provider.GCSStorageProvider
     provider.Storage
 
 
@@ -27,6 +30,8 @@ def test_get_storage_provider():
     assert isinstance(Storage, provider.LocalStorageProvider)
     Storage = provider.get_storage_provider("s3")
     assert isinstance(Storage, provider.S3StorageProvider)
+    Storage = provider.get_storage_provider("gcs")
+    assert isinstance(Storage, provider.GCSStorageProvider)
     with pytest.raises(RuntimeError):
         provider.get_storage_provider("invalid")
 
@@ -42,6 +47,7 @@ def test_class_instantiation():
         Test()
     provider.LocalStorageProvider()
     provider.S3StorageProvider()
+    provider.GCSStorageProvider()
 
 
 class TestLocalStorageProvider:
@@ -171,3 +177,105 @@ class TestS3StorageProvider:
         self.Storage.delete_all_files()
         assert not (upload_dir / self.filename).exists()
         assert not (upload_dir / self.filename_extra).exists()
+
+class TestGCSStorageProvider:
+    Storage = provider.GCSStorageProvider()
+    Storage.bucket_name = "my-bucket"
+    file_content = b"test content"
+    filename = "test.txt"
+    filename_extra = "test_exyta.txt"
+    file_bytesio_empty = io.BytesIO()
+
+    @pytest.fixture
+    def setup(self):
+        host, port = "localhost", 9023
+
+        server = create_server(host, port, in_memory=True)
+        server.start()
+        os.environ["STORAGE_EMULATOR_HOST"] = f"http://{host}:{port}"
+
+        gcs_client = storage.Client()
+        bucket = gcs_client.bucket(self.Storage.bucket_name)
+        bucket.create()
+        yield gcs_client, bucket
+        bucket.delete(force=True)
+        server.stop()
+    
+    def test_upload_file(self, monkeypatch, tmp_path, setup):
+        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
+        # test error if bucket does not exist
+        with pytest.raises(Exception):
+            self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
+        # creates bucket and test upload_file method, downloads the file and confirms contents
+        self.Storage.gcs_client, self.Storage.bucket = setup
+        contents, gcs_file_path = self.Storage.upload_file(
+            io.BytesIO(self.file_content), self.filename
+        )
+        object = self.Storage.bucket.blob(self.filename)
+        assert self.file_content == object.download_as_bytes()
+        # local checks
+        assert (upload_dir / self.filename).exists()
+        assert (upload_dir / self.filename).read_bytes() == self.file_content
+        assert contents == self.file_content
+        assert gcs_file_path == "gs://" + self.Storage.bucket_name + "/" + self.filename
+        # test error if file is empty
+        with pytest.raises(ValueError):
+            self.Storage.upload_file(self.file_bytesio_empty, self.filename)
+
+    # def test_get_file(self, monkeypatch, tmp_path):
+    #     upload_dir = mock_upload_dir(monkeypatch, tmp_path)
+    #     self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
+    #     contents, s3_file_path = self.Storage.upload_file(
+    #         io.BytesIO(self.file_content), self.filename
+    #     )
+    #     file_path = self.Storage.get_file(s3_file_path)
+    #     assert file_path == str(upload_dir / self.filename)
+    #     assert (upload_dir / self.filename).exists()
+
+    # def test_delete_file(self, monkeypatch, tmp_path):
+    #     upload_dir = mock_upload_dir(monkeypatch, tmp_path)
+    #     self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
+    #     contents, s3_file_path = self.Storage.upload_file(
+    #         io.BytesIO(self.file_content), self.filename
+    #     )
+    #     assert (upload_dir / self.filename).exists()
+    #     self.Storage.delete_file(s3_file_path)
+    #     assert not (upload_dir / self.filename).exists()
+    #     with pytest.raises(ClientError) as exc:
+    #         self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
+    #     error = exc.value.response["Error"]
+    #     assert error["Code"] == "404"
+    #     assert error["Message"] == "Not Found"
+
+    # def test_delete_all_files(self, monkeypatch, tmp_path):
+    #     upload_dir = mock_upload_dir(monkeypatch, tmp_path)
+    #     # create 2 files
+    #     self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
+    #     self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
+    #     object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
+    #     assert self.file_content == object.get()["Body"].read()
+    #     assert (upload_dir / self.filename).exists()
+    #     assert (upload_dir / self.filename).read_bytes() == self.file_content
+    #     self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
+    #     object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
+    #     assert self.file_content == object.get()["Body"].read()
+    #     assert (upload_dir / self.filename).exists()
+    #     assert (upload_dir / self.filename).read_bytes() == self.file_content
+
+    #     self.Storage.delete_all_files()
+    #     assert not (upload_dir / self.filename).exists()
+    #     with pytest.raises(ClientError) as exc:
+    #         self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
+    #     error = exc.value.response["Error"]
+    #     assert error["Code"] == "404"
+    #     assert error["Message"] == "Not Found"
+    #     assert not (upload_dir / self.filename_extra).exists()
+    #     with pytest.raises(ClientError) as exc:
+    #         self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
+    #     error = exc.value.response["Error"]
+    #     assert error["Code"] == "404"
+    #     assert error["Message"] == "Not Found"
+
+    #     self.Storage.delete_all_files()
+    #     assert not (upload_dir / self.filename).exists()
+    #     assert not (upload_dir / self.filename_extra).exists()

+ 17 - 16
pyproject.toml

@@ -13,17 +13,17 @@ dependencies = [
 
     "Flask==3.1.0",
     "Flask-Cors==5.0.0",
-
+    
     "python-socketio==5.11.3",
     "python-jose==3.3.0",
     "passlib[bcrypt]==1.7.4",
-
+    
     "requests==2.32.3",
     "aiohttp==3.11.8",
     "async-timeout",
     "aiocache",
     "aiofiles",
-
+    
     "sqlalchemy==2.0.32",
     "alembic==1.14.0",
     "peewee==3.17.8",
@@ -32,33 +32,33 @@ dependencies = [
     "pgvector==0.3.5",
     "PyMySQL==1.1.1",
     "bcrypt==4.2.0",
-
+    
     "pymongo",
     "redis",
     "boto3==1.35.53",
-
+    
     "argon2-cffi==23.1.0",
     "APScheduler==3.10.4",
-
+    
     "openai",
     "anthropic",
     "google-generativeai==0.7.2",
     "tiktoken",
-
+    
     "langchain==0.3.7",
     "langchain-community==0.3.7",
-
+    
     "fake-useragent==1.5.1",
     "chromadb==0.6.2",
     "pymilvus==2.5.0",
     "qdrant-client~=1.12.0",
     "opensearch-py==2.7.1",
-
+    
     "transformers",
     "sentence-transformers==3.3.1",
     "colbert-ai==0.2.21",
     "einops==0.8.0",
-
+    
     "ftfy==6.2.3",
     "pypdf==4.3.1",
     "fpdf2==2.8.2",
@@ -77,25 +77,25 @@ dependencies = [
     "psutil",
     "sentencepiece",
     "soundfile==0.12.1",
-
+    
     "opencv-python-headless==4.10.0.84",
     "rapidocr-onnxruntime==1.3.24",
     "rank-bm25==0.2.2",
-
+    
     "faster-whisper==1.0.3",
-
+    
     "PyJWT[crypto]==2.10.1",
     "authlib==1.3.2",
-
+    
     "black==24.8.0",
     "langfuse==2.44.0",
     "youtube-transcript-api==0.6.3",
     "pytube==15.0.0",
-
+    
     "extract_msg",
     "pydub",
     "duckduckgo-search~=6.3.5",
-
+    
     "docker~=7.1.0",
     "pytest~=8.3.2",
     "pytest-docker~=3.1.1",
@@ -104,6 +104,7 @@ dependencies = [
     "googleapis-common-protos==1.63.2",
     "ldap3==2.9.1",
     "google-cloud-storage==2.19.0",
+    "gcp-storage-emulator>=2024.8.3",
 ]
 readme = "README.md"
 requires-python = ">= 3.11, < 3.13.0a1"

+ 50 - 0
uv.lock

@@ -21,6 +21,9 @@ resolution-markers = [
     "python_full_version < '3.12' and platform_system == 'Darwin'",
     "python_full_version < '3.12' and platform_system == 'Darwin'",
     "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system == 'Darwin'",
+    "python_full_version < '3.12' and platform_system == 'Darwin'",
+    "python_full_version < '3.12' and platform_system == 'Darwin'",
+    "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system == 'Darwin'",
     "python_full_version < '3.13' and platform_system == 'Darwin'",
     "python_full_version >= '3.13' and platform_system == 'Darwin'",
     "python_full_version >= '3.13' and platform_system == 'Darwin'",
@@ -28,6 +31,10 @@ resolution-markers = [
     "python_full_version >= '3.13' and platform_system == 'Darwin'",
     "python_full_version >= '3.13' and platform_system == 'Darwin'",
     "python_full_version >= '3.13' and platform_system == 'Darwin'",
+    "python_full_version >= '3.13' and platform_system == 'Darwin'",
+    "python_full_version < '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
+    "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux'",
+    "python_full_version < '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
     "python_full_version < '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
     "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux'",
     "python_full_version < '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
@@ -52,6 +59,10 @@ resolution-markers = [
     "python_full_version >= '3.13' and platform_machine == 'aarch64' and platform_system == 'Linux'",
     "python_full_version >= '3.13' and platform_machine == 'aarch64' and platform_system == 'Linux'",
     "python_full_version >= '3.13' and platform_machine == 'aarch64' and platform_system == 'Linux'",
+    "python_full_version >= '3.13' and platform_machine == 'aarch64' and platform_system == 'Linux'",
+    "(python_full_version < '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version < '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
     "(python_full_version < '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
     "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')",
     "(python_full_version < '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
@@ -76,6 +87,7 @@ resolution-markers = [
     "(python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.13' and platform_system != 'Darwin' and platform_system != 'Linux')",
     "(python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.13' and platform_system != 'Darwin' and platform_system != 'Linux')",
     "(python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.13' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.13' and platform_system != 'Darwin' and platform_system != 'Linux')",
 ]
 
 [[package]]
@@ -234,6 +246,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 },
 ]
 
+[[package]]
+name = "appdirs"
+version = "1.4.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d7/d8/05696357e0311f5b5c316d7b95f46c669dd9c15aaeecbb48c7d0aeb88c40/appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41", size = 13470 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128", size = 9566 },
+]
+
 [[package]]
 name = "apscheduler"
 version = "3.10.4"
@@ -1276,6 +1297,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c6/c8/a5be5b7550c10858fcf9b0ea054baccab474da77d37f1e828ce043a3a5d4/frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3", size = 11901 },
 ]
 
+[[package]]
+name = "fs"
+version = "2.4.16"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "appdirs" },
+    { name = "setuptools" },
+    { name = "six" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/5d/a9/af5bfd5a92592c16cdae5c04f68187a309be8a146b528eac3c6e30edbad2/fs-2.4.16.tar.gz", hash = "sha256:ae97c7d51213f4b70b6a958292530289090de3a7e15841e108fbe144f069d313", size = 187441 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b9/5c/a3d95dc1ec6cdeb032d789b552ecc76effa3557ea9186e1566df6aac18df/fs-2.4.16-py2.py3-none-any.whl", hash = "sha256:660064febbccda264ae0b6bace80a8d1be9e089e0a5eb2427b7d517f9a91545c", size = 135261 },
+]
+
 [[package]]
 name = "fsspec"
 version = "2024.9.0"
@@ -1302,6 +1337,19 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ed/46/14d230ad057048aea7ccd2f96a80905830866d281ea90a6662a825490659/ftfy-6.2.3-py3-none-any.whl", hash = "sha256:f15761b023f3061a66207d33f0c0149ad40a8319fd16da91796363e2c049fdf8", size = 43011 },
 ]
 
+[[package]]
+name = "gcp-storage-emulator"
+version = "2024.8.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "fs" },
+    { name = "google-crc32c" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/17/c2/a0b0e1e54fdd9453603d90939faf652e2488b617c8752edc4ebcd89f1686/gcp_storage_emulator-2024.8.3.tar.gz", hash = "sha256:e5d45e5c23a0344c1c4c44b8f8c36f7e8975ca1fcc5134cab608b96ddccd9225", size = 24928 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/53/bf/b6c717bd7a5b59244388057d36789e72c28bfaa5e51f3a494563a4e3028e/gcp_storage_emulator-2024.8.3-py3-none-any.whl", hash = "sha256:1dc4ea56a0caf50fc6092898b9461d08b494824bfbbcca168e2af5da89c053ce", size = 19385 },
+]
+
 [[package]]
 name = "git-python"
 version = "1.0.3"
@@ -2769,6 +2817,7 @@ dependencies = [
     { name = "flask-cors" },
     { name = "fpdf2" },
     { name = "ftfy" },
+    { name = "gcp-storage-emulator" },
     { name = "google-cloud-storage" },
     { name = "google-generativeai" },
     { name = "googleapis-common-protos" },
@@ -2853,6 +2902,7 @@ requires-dist = [
     { name = "flask-cors", specifier = "==5.0.0" },
     { name = "fpdf2", specifier = "==2.8.2" },
     { name = "ftfy", specifier = "==6.2.3" },
+    { name = "gcp-storage-emulator", specifier = ">=2024.8.3" },
     { name = "google-cloud-storage", specifier = "==2.19.0" },
     { name = "google-generativeai", specifier = "==0.7.2" },
     { name = "googleapis-common-protos", specifier = "==1.63.2" },