Bläddra i källkod

Merge pull request #9486 from rragundez/store-images

Use DB for generated images
Timothy Jaeryang Baek 2 månader sedan
förälder
incheckning
ab70f1bb50
2 ändrade filer med 62 tillägg och 89 borttagningar
  1. 16 20
      backend/open_webui/routers/files.py
  2. 46 69
      backend/open_webui/routers/images.py

+ 16 - 20
backend/open_webui/routers/files.py

@@ -3,30 +3,22 @@ import os
 import uuid
 from pathlib import Path
 from typing import Optional
-from pydantic import BaseModel
-import mimetypes
 from urllib.parse import quote
 
-from open_webui.storage.provider import Storage
-
+from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
+from fastapi.responses import FileResponse, StreamingResponse
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import SRC_LOG_LEVELS
 from open_webui.models.files import (
     FileForm,
     FileModel,
     FileModelResponse,
     Files,
 )
-from open_webui.routers.retrieval import process_file, ProcessFileForm
-
-from open_webui.config import UPLOAD_DIR
-from open_webui.env import SRC_LOG_LEVELS
-from open_webui.constants import ERROR_MESSAGES
-
-
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
-from fastapi.responses import FileResponse, StreamingResponse
-
-
+from open_webui.routers.retrieval import ProcessFileForm, process_file
+from open_webui.storage.provider import Storage
 from open_webui.utils.auth import get_admin_user, get_verified_user
+from pydantic import BaseModel
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -41,7 +33,10 @@ router = APIRouter()
 
 @router.post("/", response_model=FileModelResponse)
 def upload_file(
-    request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
+    request: Request,
+    file: UploadFile = File(...),
+    user=Depends(get_verified_user),
+    file_metadata: dict = {},
 ):
     log.info(f"file.content_type: {file.content_type}")
     try:
@@ -65,6 +60,7 @@ def upload_file(
                         "name": name,
                         "content_type": file.content_type,
                         "size": len(contents),
+                        "data": file_metadata,
                     },
                 }
             ),
@@ -126,7 +122,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
             Storage.delete_all_files()
         except Exception as e:
             log.exception(e)
-            log.error(f"Error deleting files")
+            log.error("Error deleting files")
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
@@ -248,7 +244,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
                 )
         except Exception as e:
             log.exception(e)
-            log.error(f"Error getting file content")
+            log.error("Error getting file content")
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -279,7 +275,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
                 )
         except Exception as e:
             log.exception(e)
-            log.error(f"Error getting file content")
+            log.error("Error getting file content")
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -355,7 +351,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
                 Storage.delete_file(file.path)
             except Exception as e:
                 log.exception(e)
-                log.error(f"Error deleting files")
+                log.error("Error deleting files")
                 raise HTTPException(
                     status_code=status.HTTP_400_BAD_REQUEST,
                     detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),

+ 46 - 69
backend/open_webui/routers/images.py

@@ -1,32 +1,26 @@
 import asyncio
 import base64
+import io
 import json
 import logging
 import mimetypes
 import re
-import uuid
 from pathlib import Path
 from typing import Optional
 
 import requests
-
-
-from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-
-
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
-
+from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
+from open_webui.routers.files import upload_file
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.images.comfyui import (
     ComfyUIGenerateImageForm,
     ComfyUIWorkflow,
     comfyui_generate_image,
 )
-
+from pydantic import BaseModel
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -271,7 +265,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
 async def update_image_config(
     request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
 ):
-
     set_image_model(request, form_data.MODEL)
 
     pattern = r"^\d+x\d+$"
@@ -383,40 +376,22 @@ class GenerateImageForm(BaseModel):
     negative_prompt: Optional[str] = None
 
 
-def save_b64_image(b64_str):
+def load_b64_image_data(b64_str):
     try:
-        image_id = str(uuid.uuid4())
-
         if "," in b64_str:
             header, encoded = b64_str.split(",", 1)
             mime_type = header.split(";")[0]
-
             img_data = base64.b64decode(encoded)
-            image_format = mimetypes.guess_extension(mime_type)
-
-            image_filename = f"{image_id}{image_format}"
-            file_path = IMAGE_CACHE_DIR / f"{image_filename}"
-            with open(file_path, "wb") as f:
-                f.write(img_data)
-            return image_filename
         else:
-            image_filename = f"{image_id}.png"
-            file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
-
+            mime_type = "image/png"
             img_data = base64.b64decode(b64_str)
-
-            # Write the image data to a file
-            with open(file_path, "wb") as f:
-                f.write(img_data)
-            return image_filename
-
+        return img_data, mime_type
     except Exception as e:
-        log.exception(f"Error saving image: {e}")
+        log.exception(f"Error loading image data: {e}")
         return None
 
 
-def save_url_image(url, headers=None):
-    image_id = str(uuid.uuid4())
+def load_url_image_data(url, headers=None):
     try:
         if headers:
             r = requests.get(url, headers=headers)
@@ -426,18 +401,7 @@ def save_url_image(url, headers=None):
         r.raise_for_status()
         if r.headers["content-type"].split("/")[0] == "image":
             mime_type = r.headers["content-type"]
-            image_format = mimetypes.guess_extension(mime_type)
-
-            if not image_format:
-                raise ValueError("Could not determine image type from MIME type")
-
-            image_filename = f"{image_id}{image_format}"
-
-            file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
-            with open(file_path, "wb") as image_file:
-                for chunk in r.iter_content(chunk_size=8192):
-                    image_file.write(chunk)
-            return image_filename
+            return r.content, mime_type
         else:
             log.error("Url does not point to an image.")
             return None
@@ -447,6 +411,20 @@ def save_url_image(url, headers=None):
         return None
 
 
+def upload_image(request, image_metadata, image_data, content_type, user):
+    image_format = mimetypes.guess_extension(content_type)
+    file = UploadFile(
+        file=io.BytesIO(image_data),
+        filename=f"generated{image_format}",  # will be converted to a unique ID on upload_file
+        headers={
+            "content-type": content_type,
+        },
+    )
+    file_item = upload_file(request, file, user, file_metadata=image_metadata)
+    url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
+    return url
+
+
 @router.post("/generations")
 async def image_generations(
     request: Request,
@@ -500,13 +478,9 @@ async def image_generations(
             images = []
 
             for image in res["data"]:
-                image_filename = save_b64_image(image["b64_json"])
-                images.append({"url": f"/cache/image/generations/{image_filename}"})
-                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
-                with open(file_body_path, "w") as f:
-                    json.dump(data, f)
-
+                image_data, content_type = load_b64_image_data(image["b64_json"])
+                url = upload_image(request, data, image_data, content_type, user)
+                images.append({"url": url})
             return images
 
         elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
@@ -552,14 +526,15 @@ async def image_generations(
                         "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
                     }
 
-                image_filename = save_url_image(image["url"], headers)
-                images.append({"url": f"/cache/image/generations/{image_filename}"})
-                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
-                with open(file_body_path, "w") as f:
-                    json.dump(form_data.model_dump(exclude_none=True), f)
-
-            log.debug(f"images: {images}")
+                image_data, content_type = load_url_image_data(image["url"], headers)
+                url = upload_image(
+                    request,
+                    form_data.model_dump(exclude_none=True),
+                    image_data,
+                    content_type,
+                    user,
+                )
+                images.append({"url": url})
             return images
         elif (
             request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
@@ -604,13 +579,15 @@ async def image_generations(
             images = []
 
             for image in res["images"]:
-                image_filename = save_b64_image(image)
-                images.append({"url": f"/cache/image/generations/{image_filename}"})
-                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
-                with open(file_body_path, "w") as f:
-                    json.dump({**data, "info": res["info"]}, f)
-
+                image_data, content_type = load_b64_image_data(image)
+                url = upload_image(
+                    request,
+                    {**data, "info": res["info"]},
+                    image_data,
+                    content_type,
+                    user,
+                )
+                images.append({"url": url})
             return images
     except Exception as e:
         error = e