|
@@ -31,6 +31,9 @@ from open_webui.routers.tasks import (
|
|
|
generate_chat_tags,
|
|
|
)
|
|
|
from open_webui.routers.retrieval import process_web_search, SearchForm
|
|
|
+from open_webui.routers.images import image_generations, GenerateImageForm
|
|
|
+
|
|
|
+
|
|
|
from open_webui.utils.webhook import post_webhook
|
|
|
|
|
|
|
|
@@ -486,6 +489,67 @@ async def chat_web_search_handler(
|
|
|
return form_data
|
|
|
|
|
|
|
|
|
+async def chat_image_generation_handler(
|
|
|
+ request: Request, form_data: dict, extra_params: dict, user
|
|
|
+):
|
|
|
+ __event_emitter__ = extra_params["__event_emitter__"]
|
|
|
+ await __event_emitter__(
|
|
|
+ {
|
|
|
+ "type": "status",
|
|
|
+ "data": {"description": "Generating an image", "done": False},
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ messages = form_data["messages"]
|
|
|
+ user_message = get_last_user_message(messages)
|
|
|
+
|
|
|
+ system_message_content = ""
|
|
|
+
|
|
|
+ try:
|
|
|
+ images = await image_generations(
|
|
|
+ request=request,
|
|
|
+ form_data=GenerateImageForm(**{"prompt": user_message}),
|
|
|
+ user=user,
|
|
|
+ )
|
|
|
+
|
|
|
+ await __event_emitter__(
|
|
|
+ {
|
|
|
+ "type": "status",
|
|
|
+ "data": {"description": "Generated an image", "done": True},
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ for image in images:
|
|
|
+ await __event_emitter__(
|
|
|
+ {
|
|
|
+ "type": "message",
|
|
|
+ "data": {"content": f""},
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(e)
|
|
|
+ await __event_emitter__(
|
|
|
+ {
|
|
|
+ "type": "status",
|
|
|
+ "data": {
|
|
|
+ "description": f"An error occured while generating an image",
|
|
|
+ "done": True,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ system_message_content = "<context>Unable to generate an image, tell the user that an error occured</context>"
|
|
|
+
|
|
|
+ if system_message_content:
|
|
|
+ form_data["messages"] = add_or_update_system_message(
|
|
|
+ system_message_content, form_data["messages"]
|
|
|
+ )
|
|
|
+
|
|
|
+ return form_data
|
|
|
+
|
|
|
+
|
|
|
async def chat_completion_files_handler(
|
|
|
request: Request, body: dict, user: UserModel
|
|
|
) -> tuple[dict, dict[str, list]]:
|
|
@@ -640,6 +704,11 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|
|
request, form_data, extra_params, user
|
|
|
)
|
|
|
|
|
|
+ if "image_generation" in features and features["image_generation"]:
|
|
|
+ form_data = await chat_image_generation_handler(
|
|
|
+ request, form_data, extra_params, user
|
|
|
+ )
|
|
|
+
|
|
|
try:
|
|
|
form_data, flags = await chat_completion_filter_functions_handler(
|
|
|
request, form_data, model, extra_params
|