Browse Source

Add functionality in other image generation types

Rodrigo Agundez 3 months ago
parent
commit
8d43fdadc1
1 changed files with 40 additions and 24 deletions
  1. 40 24
      backend/open_webui/routers/images.py

+ 40 - 24
backend/open_webui/routers/images.py

@@ -1,5 +1,6 @@
 import asyncio
 import base64
+import io
 import json
 import logging
 import mimetypes
@@ -7,10 +8,9 @@ import re
 import uuid
 from pathlib import Path
 from typing import Optional
-import io
 
 import requests
-from fastapi import APIRouter, Depends, UploadFile, HTTPException, Request
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
@@ -392,8 +392,7 @@ def load_b64_image_data(b64_str):
         return None
 
 
-def save_url_image(url, headers=None):
-    image_id = str(uuid.uuid4())
+def load_url_image_data(url, headers=None):
     try:
         if headers:
             r = requests.get(url, headers=headers)
@@ -403,18 +402,7 @@ def save_url_image(url, headers=None):
         r.raise_for_status()
         if r.headers["content-type"].split("/")[0] == "image":
             mime_type = r.headers["content-type"]
-            image_format = mimetypes.guess_extension(mime_type)
-
-            if not image_format:
-                raise ValueError("Could not determine image type from MIME type")
-
-            image_filename = f"{image_id}{image_format}"
-
-            file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
-            with open(file_path, "wb") as image_file:
-                for chunk in r.iter_content(chunk_size=8192):
-                    image_file.write(chunk)
-            return image_filename
+            return r.content, mime_type
         else:
             log.error("Url does not point to an image.")
             return None
@@ -486,8 +474,14 @@ async def image_generations(
                     },
                 )
                 file_item = upload_file(request, file, user)
-                url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
+                url = request.app.url_path_for(
+                    "get_file_content_by_id", id=file_item.id
+                )
                 images.append({"url": url})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{file_item.id}.json")
+
+                with open(file_body_path, "w") as f:
+                    json.dump(data, f)
             return images
 
         elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
@@ -533,9 +527,20 @@ async def image_generations(
                         "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
                     }
 
-                image_filename = save_url_image(image["url"], headers)
-                images.append({"url": f"/cache/image/generations/{image_filename}"})
-                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
+                image_data, content_type = load_url_image_data(image["url"], headers)
+                file = UploadFile(
+                    file=io.BytesIO(image_data),
+                    filename="image",  # will be converted to a unique ID on upload_file
+                    headers={
+                        "content-type": content_type,
+                    },
+                )
+                file_item = upload_file(request, file, user)
+                url = request.app.url_path_for(
+                    "get_file_content_by_id", id=file_item.id
+                )
+                images.append({"url": url})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{file_item.id}.json")
 
                 with open(file_body_path, "w") as f:
                     json.dump(form_data.model_dump(exclude_none=True), f)
@@ -585,9 +590,20 @@ async def image_generations(
             images = []
 
             for image in res["images"]:
-                image_filename = save_b64_image(image)
-                images.append({"url": f"/cache/image/generations/{image_filename}"})
-                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
+                image_data, content_type = load_b64_image_data(image)
+                file = UploadFile(
+                    file=io.BytesIO(image_data),
+                    filename="image",  # will be converted to a unique ID on upload_file
+                    headers={
+                        "content-type": content_type,
+                    },
+                )
+                file_item = upload_file(request, file, user)
+                url = request.app.url_path_for(
+                    "get_file_content_by_id", id=file_item.id
+                )
+                images.append({"url": url})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{file_item.id}.json")
 
                 with open(file_body_path, "w") as f:
                     json.dump({**data, "info": res["info"]}, f)
@@ -599,4 +615,4 @@ async def image_generations(
             data = r.json()
             if "error" in data:
                 error = data["error"]["message"]
-        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
+        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))