Prechádzať zdrojové kódy

Merge pull request #1815 from Yanyutin753/new-dev

✨ expend the image format type after the file is downloaded
Timothy Jaeryang Baek 1 rok pred
rodič
commit
de62153d49
1 zmenil súbory, kde vykonal 36 pridanie a 21 odobranie
  1. 36 21
      backend/apps/images/main.py

+ 36 - 21
backend/apps/images/main.py

@@ -24,6 +24,7 @@ from utils.misc import calculate_sha256
 from typing import Optional
 from typing import Optional
 from pydantic import BaseModel
 from pydantic import BaseModel
 from pathlib import Path
 from pathlib import Path
+import mimetypes
 import uuid
 import uuid
 import base64
 import base64
 import json
 import json
@@ -315,38 +316,50 @@ class GenerateImageForm(BaseModel):
 
 
 
 
 def save_b64_image(b64_str):
 def save_b64_image(b64_str):
-    image_id = str(uuid.uuid4())
-    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
-
     try:
     try:
-        # Split the base64 string to get the actual image data
-        img_data = base64.b64decode(b64_str)
+        header, encoded = b64_str.split(",", 1)
+        mime_type = header.split(";")[0]
+
+        img_data = base64.b64decode(encoded)
 
 
-        # Write the image data to a file
+        image_id = str(uuid.uuid4())
+        image_format = mimetypes.guess_extension(mime_type)
+
+        image_filename = f"{image_id}{image_format}"
+        file_path = IMAGE_CACHE_DIR / f"{image_filename}"
         with open(file_path, "wb") as f:
         with open(file_path, "wb") as f:
             f.write(img_data)
             f.write(img_data)
-
-        return image_id
+        return image_filename
     except Exception as e:
     except Exception as e:
-        log.error(f"Error saving image: {e}")
+        log.exception(f"Error saving image: {e}")
         return None
         return None
 
 
 
 
 def save_url_image(url):
 def save_url_image(url):
     image_id = str(uuid.uuid4())
     image_id = str(uuid.uuid4())
-    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
-
     try:
     try:
         r = requests.get(url)
         r = requests.get(url)
         r.raise_for_status()
         r.raise_for_status()
+        if r.headers["content-type"].split("/")[0] == "image":
+
+            mime_type = r.headers["content-type"]
+            image_format = mimetypes.guess_extension(mime_type)
+
+            if not image_format:
+                raise ValueError("Could not determine image type from MIME type")
 
 
-        with open(file_path, "wb") as image_file:
-            image_file.write(r.content)
+            file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}")
+            with open(file_path, "wb") as image_file:
+                for chunk in r.iter_content(chunk_size=8192):
+                    image_file.write(chunk)
+            return image_id, image_format
+        else:
+            log.error(f"Url does not point to an image.")
+            return None, None
 
 
-        return image_id
     except Exception as e:
     except Exception as e:
         log.exception(f"Error saving image: {e}")
         log.exception(f"Error saving image: {e}")
-        return None
+        return None, None
 
 
 
 
 @app.post("/generations")
 @app.post("/generations")
@@ -385,8 +398,8 @@ def generate_image(
             images = []
             images = []
 
 
             for image in res["data"]:
             for image in res["data"]:
-                image_id = save_b64_image(image["b64_json"])
-                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                image_filename = save_b64_image(image["b64_json"])
+                images.append({"url": f"/cache/image/generations/{image_filename}"})
                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
 
 
                 with open(file_body_path, "w") as f:
                 with open(file_body_path, "w") as f:
@@ -422,8 +435,10 @@ def generate_image(
             images = []
             images = []
 
 
             for image in res["data"]:
             for image in res["data"]:
-                image_id = save_url_image(image["url"])
-                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                image_id, image_format = save_url_image(image["url"])
+                images.append(
+                    {"url": f"/cache/image/generations/{image_id}{image_format}"}
+                )
                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
 
 
                 with open(file_body_path, "w") as f:
                 with open(file_body_path, "w") as f:
@@ -460,8 +475,8 @@ def generate_image(
             images = []
             images = []
 
 
             for image in res["images"]:
             for image in res["images"]:
-                image_id = save_b64_image(image)
-                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                image_filename = save_b64_image(image)
+                images.append({"url": f"/cache/image/generations/{image_filename}"})
                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
 
 
                 with open(file_body_path, "w") as f:
                 with open(file_body_path, "w") as f: