|
@@ -24,6 +24,7 @@ from utils.misc import calculate_sha256
|
|
|
from typing import Optional
|
|
|
from pydantic import BaseModel
|
|
|
from pathlib import Path
|
|
|
+import mimetypes
|
|
|
import uuid
|
|
|
import base64
|
|
|
import json
|
|
@@ -315,38 +316,50 @@ class GenerateImageForm(BaseModel):
|
|
|
|
|
|
|
|
|
def save_b64_image(b64_str):
|
|
|
- image_id = str(uuid.uuid4())
|
|
|
- file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
|
|
|
-
|
|
|
try:
|
|
|
- # Split the base64 string to get the actual image data
|
|
|
- img_data = base64.b64decode(b64_str)
|
|
|
+ header, encoded = b64_str.split(",", 1)
|
|
|
+ mime_type = header.split(";")[0]
|
|
|
+
|
|
|
+ img_data = base64.b64decode(encoded)
|
|
|
|
|
|
- # Write the image data to a file
|
|
|
+ image_id = str(uuid.uuid4())
|
|
|
+ image_format = mimetypes.guess_extension(mime_type)
|
|
|
+
|
|
|
+ image_filename = f"{image_id}{image_format}"
|
|
|
+ file_path = IMAGE_CACHE_DIR / f"{image_filename}"
|
|
|
with open(file_path, "wb") as f:
|
|
|
f.write(img_data)
|
|
|
-
|
|
|
- return image_id
|
|
|
+ return image_filename
|
|
|
except Exception as e:
|
|
|
- log.error(f"Error saving image: {e}")
|
|
|
+ log.exception(f"Error saving image: {e}")
|
|
|
return None
|
|
|
|
|
|
|
|
|
def save_url_image(url):
|
|
|
image_id = str(uuid.uuid4())
|
|
|
- file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
|
|
|
-
|
|
|
try:
|
|
|
r = requests.get(url)
|
|
|
r.raise_for_status()
|
|
|
+ if r.headers["content-type"].split("/")[0] == "image":
|
|
|
+
|
|
|
+ mime_type = r.headers["content-type"]
|
|
|
+ image_format = mimetypes.guess_extension(mime_type)
|
|
|
+
|
|
|
+ if not image_format:
|
|
|
+ raise ValueError("Could not determine image type from MIME type")
|
|
|
|
|
|
- with open(file_path, "wb") as image_file:
|
|
|
- image_file.write(r.content)
|
|
|
+ file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}")
|
|
|
+ with open(file_path, "wb") as image_file:
|
|
|
+ for chunk in r.iter_content(chunk_size=8192):
|
|
|
+ image_file.write(chunk)
|
|
|
+ return image_id, image_format
|
|
|
+ else:
|
|
|
+ log.error(f"Url does not point to an image.")
|
|
|
+ return None, None
|
|
|
|
|
|
- return image_id
|
|
|
except Exception as e:
|
|
|
log.exception(f"Error saving image: {e}")
|
|
|
- return None
|
|
|
+ return None, None
|
|
|
|
|
|
|
|
|
@app.post("/generations")
|
|
@@ -385,8 +398,8 @@ def generate_image(
|
|
|
images = []
|
|
|
|
|
|
for image in res["data"]:
|
|
|
- image_id = save_b64_image(image["b64_json"])
|
|
|
- images.append({"url": f"/cache/image/generations/{image_id}.png"})
|
|
|
+ image_filename = save_b64_image(image["b64_json"])
|
|
|
+ images.append({"url": f"/cache/image/generations/{image_filename}"})
|
|
|
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
|
|
|
|
|
|
with open(file_body_path, "w") as f:
|
|
@@ -422,8 +435,10 @@ def generate_image(
|
|
|
images = []
|
|
|
|
|
|
for image in res["data"]:
|
|
|
- image_id = save_url_image(image["url"])
|
|
|
- images.append({"url": f"/cache/image/generations/{image_id}.png"})
|
|
|
+ image_id, image_format = save_url_image(image["url"])
|
|
|
+ images.append(
|
|
|
+ {"url": f"/cache/image/generations/{image_id}{image_format}"}
|
|
|
+ )
|
|
|
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
|
|
|
|
|
|
with open(file_body_path, "w") as f:
|
|
@@ -460,8 +475,8 @@ def generate_image(
|
|
|
images = []
|
|
|
|
|
|
for image in res["images"]:
|
|
|
- image_id = save_b64_image(image)
|
|
|
- images.append({"url": f"/cache/image/generations/{image_id}.png"})
|
|
|
+ image_filename = save_b64_image(image)
|
|
|
+ images.append({"url": f"/cache/image/generations/{image_filename}"})
|
|
|
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
|
|
|
|
|
|
with open(file_body_path, "w") as f:
|