|
@@ -7,26 +7,21 @@ import re
|
|
|
import uuid
|
|
|
from pathlib import Path
|
|
|
from typing import Optional
|
|
|
+import io
|
|
|
|
|
|
import requests
|
|
|
-
|
|
|
-
|
|
|
-from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
|
|
|
-from fastapi.middleware.cors import CORSMiddleware
|
|
|
-from pydantic import BaseModel
|
|
|
-
|
|
|
-
|
|
|
+from fastapi import APIRouter, Depends, UploadFile, HTTPException, Request
|
|
|
from open_webui.config import CACHE_DIR
|
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
|
-from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
|
|
|
-
|
|
|
+from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
|
|
+from open_webui.routers.files import upload_file
|
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
|
from open_webui.utils.images.comfyui import (
|
|
|
ComfyUIGenerateImageForm,
|
|
|
ComfyUIWorkflow,
|
|
|
comfyui_generate_image,
|
|
|
)
|
|
|
-
|
|
|
+from pydantic import BaseModel
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
|
@@ -39,7 +34,7 @@ router = APIRouter()
|
|
|
|
|
|
|
|
|
@router.get("/config")
|
|
|
-async def get_config(request: Request, user=Depends(get_admin_user)):
|
|
|
+async def get_def(request: Request, user=Depends(get_admin_user)):
|
|
|
return {
|
|
|
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
|
|
|
"engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
|
|
@@ -271,7 +266,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
|
|
|
async def update_image_config(
|
|
|
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
|
|
|
):
|
|
|
-
|
|
|
set_image_model(request, form_data.MODEL)
|
|
|
|
|
|
pattern = r"^\d+x\d+$"
|
|
@@ -383,35 +377,18 @@ class GenerateImageForm(BaseModel):
|
|
|
negative_prompt: Optional[str] = None
|
|
|
|
|
|
|
|
|
-def save_b64_image(b64_str):
|
|
|
+def load_b64_image_data(b64_str):
|
|
|
try:
|
|
|
- image_id = str(uuid.uuid4())
|
|
|
-
|
|
|
if "," in b64_str:
|
|
|
header, encoded = b64_str.split(",", 1)
|
|
|
mime_type = header.split(";")[0]
|
|
|
-
|
|
|
img_data = base64.b64decode(encoded)
|
|
|
- image_format = mimetypes.guess_extension(mime_type)
|
|
|
-
|
|
|
- image_filename = f"{image_id}{image_format}"
|
|
|
- file_path = IMAGE_CACHE_DIR / f"{image_filename}"
|
|
|
- with open(file_path, "wb") as f:
|
|
|
- f.write(img_data)
|
|
|
- return image_filename
|
|
|
else:
|
|
|
- image_filename = f"{image_id}.png"
|
|
|
- file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
|
|
|
-
|
|
|
+ mime_type = "image/png"
|
|
|
img_data = base64.b64decode(b64_str)
|
|
|
-
|
|
|
- # Write the image data to a file
|
|
|
- with open(file_path, "wb") as f:
|
|
|
- f.write(img_data)
|
|
|
- return image_filename
|
|
|
-
|
|
|
+ return img_data, mime_type
|
|
|
except Exception as e:
|
|
|
- log.exception(f"Error saving image: {e}")
|
|
|
+ log.exception(f"Error loading image data: {e}")
|
|
|
return None
|
|
|
|
|
|
|
|
@@ -500,13 +477,17 @@ async def image_generations(
|
|
|
images = []
|
|
|
|
|
|
for image in res["data"]:
|
|
|
- image_filename = save_b64_image(image["b64_json"])
|
|
|
- images.append({"url": f"/cache/image/generations/{image_filename}"})
|
|
|
- file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
|
|
-
|
|
|
- with open(file_body_path, "w") as f:
|
|
|
- json.dump(data, f)
|
|
|
-
|
|
|
+ image_data, content_type = load_b64_image_data(image["b64_json"])
|
|
|
+ file = UploadFile(
|
|
|
+ file=io.BytesIO(image_data),
|
|
|
+ filename="image", # will be converted to a unique ID on upload_file
|
|
|
+ headers={
|
|
|
+ "content-type": content_type,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ file_item = upload_file(request, file, user)
|
|
|
+ url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
|
|
+ images.append({"url": url})
|
|
|
return images
|
|
|
|
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
@@ -618,4 +599,4 @@ async def image_generations(
|
|
|
data = r.json()
|
|
|
if "error" in data:
|
|
|
error = data["error"]["message"]
|
|
|
- raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
|
|
|
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
|