Browse Source

Enable usage of the DB to store generated images

Rodrigo Agundez 3 months ago
parent
commit
159578dfd4
1 changed files with 22 additions and 41 deletions
  1. 22 41
      backend/open_webui/routers/images.py

+ 22 - 41
backend/open_webui/routers/images.py

@@ -7,26 +7,21 @@ import re
 import uuid
 from pathlib import Path
 from typing import Optional
+import io
 
 import requests
-
-
-from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-
-
+from fastapi import APIRouter, Depends, UploadFile, HTTPException, Request
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
-
+from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
+from open_webui.routers.files import upload_file
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.images.comfyui import (
     ComfyUIGenerateImageForm,
     ComfyUIWorkflow,
     comfyui_generate_image,
 )
-
+from pydantic import BaseModel
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -39,7 +34,7 @@ router = APIRouter()
 
 
 @router.get("/config")
-async def get_config(request: Request, user=Depends(get_admin_user)):
+async def get_def(request: Request, user=Depends(get_admin_user)):
     return {
         "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
         "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
@@ -271,7 +266,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
 async def update_image_config(
     request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
 ):
-
     set_image_model(request, form_data.MODEL)
 
     pattern = r"^\d+x\d+$"
@@ -383,35 +377,18 @@ class GenerateImageForm(BaseModel):
     negative_prompt: Optional[str] = None
 
 
-def save_b64_image(b64_str):
+def load_b64_image_data(b64_str):
     try:
-        image_id = str(uuid.uuid4())
-
         if "," in b64_str:
             header, encoded = b64_str.split(",", 1)
             mime_type = header.split(";")[0]
-
             img_data = base64.b64decode(encoded)
-            image_format = mimetypes.guess_extension(mime_type)
-
-            image_filename = f"{image_id}{image_format}"
-            file_path = IMAGE_CACHE_DIR / f"{image_filename}"
-            with open(file_path, "wb") as f:
-                f.write(img_data)
-            return image_filename
         else:
-            image_filename = f"{image_id}.png"
-            file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
-
+            mime_type = "image/png"
             img_data = base64.b64decode(b64_str)
-
-            # Write the image data to a file
-            with open(file_path, "wb") as f:
-                f.write(img_data)
-            return image_filename
-
+        return img_data, mime_type
     except Exception as e:
-        log.exception(f"Error saving image: {e}")
+        log.exception(f"Error loading image data: {e}")
         return None
 
 
@@ -500,13 +477,17 @@ async def image_generations(
             images = []
 
             for image in res["data"]:
-                image_filename = save_b64_image(image["b64_json"])
-                images.append({"url": f"/cache/image/generations/{image_filename}"})
-                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
-                with open(file_body_path, "w") as f:
-                    json.dump(data, f)
-
+                image_data, content_type = load_b64_image_data(image["b64_json"])
+                file = UploadFile(
+                    file=io.BytesIO(image_data),
+                    filename="image",  # will be converted to a unique ID on upload_file
+                    headers={
+                        "content-type": content_type,
+                    },
+                )
+                file_item = upload_file(request, file, user)
+                url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
+                images.append({"url": url})
             return images
 
         elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
@@ -618,4 +599,4 @@ async def image_generations(
             data = r.json()
             if "error" in data:
                 error = data["error"]["message"]
-        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
+        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))