Browse Source

enh: image prompt enhancer

Timothy Jaeryang Baek 3 months ago
parent
commit
0360aa5520

+ 26 - 0
backend/open_webui/config.py

@@ -1055,6 +1055,32 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
 {{MESSAGES:END:6}}
 {{MESSAGES:END:6}}
 </chat_history>"""
 </chat_history>"""
 
 
+IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
+    "task.image.prompt_template",
+    os.environ.get("IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE", ""),
+)
+
+DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = """### Task:
+Generate a detailed prompt for am image generation task based on the given language and context. Describe the image as if you were explaining it to someone who cannot see it. Include relevant details, colors, shapes, and any other important elements.
+
+### Guidelines:
+- Be descriptive and detailed, focusing on the most important aspects of the image.
+- Avoid making assumptions or adding information not present in the image.
+- Use the chat's primary language; default to English if multilingual.
+- If the image is too complex, focus on the most prominent elements.
+
+### Output:
+Strictly return in JSON format:
+{
+    "prompt": "Your detailed description here."
+}
+
+### Chat History:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>"""
+
 ENABLE_TAGS_GENERATION = PersistentConfig(
 ENABLE_TAGS_GENERATION = PersistentConfig(
     "ENABLE_TAGS_GENERATION",
     "ENABLE_TAGS_GENERATION",
     "task.tags.enable",
     "task.tags.enable",

+ 1 - 0
backend/open_webui/constants.py

@@ -113,6 +113,7 @@ class TASKS(str, Enum):
     TAGS_GENERATION = "tags_generation"
     TAGS_GENERATION = "tags_generation"
     EMOJI_GENERATION = "emoji_generation"
     EMOJI_GENERATION = "emoji_generation"
     QUERY_GENERATION = "query_generation"
     QUERY_GENERATION = "query_generation"
+    IMAGE_PROMPT_GENERATION = "image_prompt_generation"
     AUTOCOMPLETE_GENERATION = "autocomplete_generation"
     AUTOCOMPLETE_GENERATION = "autocomplete_generation"
     FUNCTION_CALLING = "function_calling"
     FUNCTION_CALLING = "function_calling"
     MOA_RESPONSE_GENERATION = "moa_response_generation"
     MOA_RESPONSE_GENERATION = "moa_response_generation"

+ 5 - 0
backend/open_webui/main.py

@@ -255,6 +255,7 @@ from open_webui.config import (
     ENABLE_AUTOCOMPLETE_GENERATION,
     ENABLE_AUTOCOMPLETE_GENERATION,
     TITLE_GENERATION_PROMPT_TEMPLATE,
     TITLE_GENERATION_PROMPT_TEMPLATE,
     TAGS_GENERATION_PROMPT_TEMPLATE,
     TAGS_GENERATION_PROMPT_TEMPLATE,
+    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     QUERY_GENERATION_PROMPT_TEMPLATE,
     QUERY_GENERATION_PROMPT_TEMPLATE,
     AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
     AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
@@ -644,6 +645,10 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
 
 
 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
 app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
 app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
+app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
+    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
+)
+
 app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
 app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 )
 )

+ 65 - 0
backend/open_webui/routers/tasks.py

@@ -9,6 +9,7 @@ from open_webui.utils.chat import generate_chat_completion
 from open_webui.utils.task import (
 from open_webui.utils.task import (
     title_generation_template,
     title_generation_template,
     query_generation_template,
     query_generation_template,
+    image_prompt_generation_template,
     autocomplete_generation_template,
     autocomplete_generation_template,
     tags_generation_template,
     tags_generation_template,
     emoji_generation_template,
     emoji_generation_template,
@@ -23,6 +24,7 @@ from open_webui.utils.task import get_task_model_id
 from open_webui.config import (
 from open_webui.config import (
     DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
+    DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
@@ -50,6 +52,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
         "TASK_MODEL": request.app.state.config.TASK_MODEL,
         "TASK_MODEL": request.app.state.config.TASK_MODEL,
         "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
         "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
         "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
         "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
         "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
         "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
         "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
@@ -65,6 +68,7 @@ class TaskConfigForm(BaseModel):
     TASK_MODEL: Optional[str]
     TASK_MODEL: Optional[str]
     TASK_MODEL_EXTERNAL: Optional[str]
     TASK_MODEL_EXTERNAL: Optional[str]
     TITLE_GENERATION_PROMPT_TEMPLATE: str
     TITLE_GENERATION_PROMPT_TEMPLATE: str
+    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
     ENABLE_AUTOCOMPLETE_GENERATION: bool
     ENABLE_AUTOCOMPLETE_GENERATION: bool
     AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
     AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
     TAGS_GENERATION_PROMPT_TEMPLATE: str
     TAGS_GENERATION_PROMPT_TEMPLATE: str
@@ -114,6 +118,7 @@ async def update_task_config(
         "TASK_MODEL": request.app.state.config.TASK_MODEL,
         "TASK_MODEL": request.app.state.config.TASK_MODEL,
         "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
         "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
         "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
         "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
         "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
         "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
         "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
@@ -256,6 +261,66 @@ async def generate_chat_tags(
         )
         )
 
 
 
 
+@router.post("/image_prompt/completions")
+async def generate_image_prompt(
+    request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+    models = request.app.state.MODELS
+
+    model_id = form_data["model"]
+    if model_id not in models:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    task_model_id = get_task_model_id(
+        model_id,
+        request.app.state.config.TASK_MODEL,
+        request.app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
+
+    log.debug(
+        f"generating image prompt using model {task_model_id} for user {user.email} "
+    )
+
+    if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
+        template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
+    else:
+        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
+
+    content = image_prompt_generation_template(
+        template,
+        form_data["messages"],
+        user={
+            "name": user.name,
+        },
+    )
+
+    payload = {
+        "model": task_model_id,
+        "messages": [{"role": "user", "content": content}],
+        "stream": False,
+        "metadata": {
+            "task": str(TASKS.IMAGE_PROMPT_GENERATION),
+            "task_body": form_data,
+            "chat_id": form_data.get("chat_id", None),
+        },
+    }
+
+    try:
+        return await generate_chat_completion(request, form_data=payload, user=user)
+    except Exception as e:
+        log.error("Exception occurred", exc_info=True)
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            content={"detail": "An internal error has occurred."},
+        )
+
+
 @router.post("/queries/completions")
 @router.post("/queries/completions")
 async def generate_queries(
 async def generate_queries(
     request: Request, form_data: dict, user=Depends(get_verified_user)
     request: Request, form_data: dict, user=Depends(get_verified_user)

+ 34 - 1
backend/open_webui/utils/middleware.py

@@ -28,6 +28,7 @@ from open_webui.socket.main import (
 from open_webui.routers.tasks import (
 from open_webui.routers.tasks import (
     generate_queries,
     generate_queries,
     generate_title,
     generate_title,
+    generate_image_prompt,
     generate_chat_tags,
     generate_chat_tags,
 )
 )
 from open_webui.routers.retrieval import process_web_search, SearchForm
 from open_webui.routers.retrieval import process_web_search, SearchForm
@@ -503,12 +504,44 @@ async def chat_image_generation_handler(
     messages = form_data["messages"]
     messages = form_data["messages"]
     user_message = get_last_user_message(messages)
     user_message = get_last_user_message(messages)
 
 
+    prompt = ""
+    negative_prompt = ""
+
+    try:
+        res = await generate_image_prompt(
+            request,
+            {
+                "model": form_data["model"],
+                "messages": messages,
+            },
+            user,
+        )
+
+        response = res["choices"][0]["message"]["content"]
+
+        try:
+            bracket_start = response.find("{")
+            bracket_end = response.rfind("}") + 1
+
+            if bracket_start == -1 or bracket_end == -1:
+                raise Exception("No JSON object found in the response")
+
+            response = response[bracket_start:bracket_end]
+            response = json.loads(response)
+            prompt = response.get("prompt", [])
+        except Exception as e:
+            prompt = user_message
+
+    except Exception as e:
+        log.exception(e)
+        prompt = user_message
+
     system_message_content = ""
     system_message_content = ""
 
 
     try:
     try:
         images = await image_generations(
         images = await image_generations(
             request=request,
             request=request,
-            form_data=GenerateImageForm(**{"prompt": user_message}),
+            form_data=GenerateImageForm(**{"prompt": prompt}),
             user=user,
             user=user,
         )
         )
 
 

+ 18 - 0
backend/open_webui/utils/task.py

@@ -217,6 +217,24 @@ def tags_generation_template(
     return template
     return template
 
 
 
 
+def image_prompt_generation_template(
+    template: str, messages: list[dict], user: Optional[dict] = None
+) -> str:
+    prompt = get_last_user_message(messages)
+    template = replace_prompt_variable(template, prompt)
+    template = replace_messages_variable(template, messages)
+
+    template = prompt_template(
+        template,
+        **(
+            {"user_name": user.get("name"), "user_location": user.get("location")}
+            if user
+            else {}
+        ),
+    )
+    return template
+
+
 def emoji_generation_template(
 def emoji_generation_template(
     template: str, prompt: str, user: Optional[dict] = None
     template: str, prompt: str, user: Optional[dict] = None
 ) -> str:
 ) -> str:

+ 17 - 0
src/lib/components/admin/Settings/Interface.svelte

@@ -24,6 +24,7 @@
 		TASK_MODEL: '',
 		TASK_MODEL: '',
 		TASK_MODEL_EXTERNAL: '',
 		TASK_MODEL_EXTERNAL: '',
 		TITLE_GENERATION_PROMPT_TEMPLATE: '',
 		TITLE_GENERATION_PROMPT_TEMPLATE: '',
+		IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: '',
 		ENABLE_AUTOCOMPLETE_GENERATION: true,
 		ENABLE_AUTOCOMPLETE_GENERATION: true,
 		AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1,
 		AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1,
 		TAGS_GENERATION_PROMPT_TEMPLATE: '',
 		TAGS_GENERATION_PROMPT_TEMPLATE: '',
@@ -140,6 +141,22 @@
 					</Tooltip>
 					</Tooltip>
 				</div>
 				</div>
 
 
+				<div class="mt-3">
+					<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Image Prompt Generation Prompt')}</div>
+
+					<Tooltip
+						content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
+						placement="top-start"
+					>
+						<Textarea
+							bind:value={taskConfig.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE}
+							placeholder={$i18n.t(
+								'Leave empty to use the default prompt, or enter a custom prompt'
+							)}
+						/>
+					</Tooltip>
+				</div>
+
 				<hr class=" border-gray-50 dark:border-gray-850 my-3" />
 				<hr class=" border-gray-50 dark:border-gray-850 my-3" />
 
 
 				<div class="my-3 flex w-full items-center justify-between">
 				<div class="my-3 flex w-full items-center justify-between">