Timothy Jaeryang Baek 3 months ago
parent
commit
0425621494

+ 7 - 0
backend/open_webui/config.py

@@ -1650,6 +1650,13 @@ ENABLE_IMAGE_GENERATION = PersistentConfig(
     "image_generation.enable",
     "image_generation.enable",
     os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
     os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
 )
 )
+
+ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig(
+    "ENABLE_IMAGE_PROMPT_GENERATION",
+    "image_generation.prompt.enable",
+    os.environ.get("ENABLE_IMAGE_PROMPT_GENERATION", "true").lower() == "true",
+)
+
 AUTOMATIC1111_BASE_URL = PersistentConfig(
 AUTOMATIC1111_BASE_URL = PersistentConfig(
     "AUTOMATIC1111_BASE_URL",
     "AUTOMATIC1111_BASE_URL",
     "image_generation.automatic1111.base_url",
     "image_generation.automatic1111.base_url",

+ 2 - 0
backend/open_webui/main.py

@@ -108,6 +108,7 @@ from open_webui.config import (
     COMFYUI_WORKFLOW,
     COMFYUI_WORKFLOW,
     COMFYUI_WORKFLOW_NODES,
     COMFYUI_WORKFLOW_NODES,
     ENABLE_IMAGE_GENERATION,
     ENABLE_IMAGE_GENERATION,
+    ENABLE_IMAGE_PROMPT_GENERATION,
     IMAGE_GENERATION_ENGINE,
     IMAGE_GENERATION_ENGINE,
     IMAGE_GENERATION_MODEL,
     IMAGE_GENERATION_MODEL,
     IMAGE_SIZE,
     IMAGE_SIZE,
@@ -575,6 +576,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
 
 
 app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE
 app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE
 app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION
 app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION
+app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
 
 
 app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
 app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
 app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
 app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

+ 7 - 0
backend/open_webui/routers/images.py

@@ -43,6 +43,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
     return {
     return {
         "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
         "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
         "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
         "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
+        "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
         "openai": {
         "openai": {
             "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
             "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
             "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
             "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
@@ -86,6 +87,7 @@ class ComfyUIConfigForm(BaseModel):
 class ConfigForm(BaseModel):
 class ConfigForm(BaseModel):
     enabled: bool
     enabled: bool
     engine: str
     engine: str
+    prompt_generation: bool
     openai: OpenAIConfigForm
     openai: OpenAIConfigForm
     automatic1111: Automatic1111ConfigForm
     automatic1111: Automatic1111ConfigForm
     comfyui: ComfyUIConfigForm
     comfyui: ComfyUIConfigForm
@@ -98,6 +100,10 @@ async def update_config(
     request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
     request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
     request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
     request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
 
 
+    request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
+        form_data.prompt_generation
+    )
+
     request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
     request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
         form_data.openai.OPENAI_API_BASE_URL
         form_data.openai.OPENAI_API_BASE_URL
     )
     )
@@ -137,6 +143,7 @@ async def update_config(
     return {
     return {
         "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
         "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
         "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
         "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
+        "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
         "openai": {
         "openai": {
             "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
             "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
             "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
             "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,

+ 24 - 23
backend/open_webui/utils/middleware.py

@@ -504,38 +504,39 @@ async def chat_image_generation_handler(
     messages = form_data["messages"]
     messages = form_data["messages"]
     user_message = get_last_user_message(messages)
     user_message = get_last_user_message(messages)
 
 
-    prompt = ""
+    prompt = user_message
     negative_prompt = ""
     negative_prompt = ""
 
 
-    try:
-        res = await generate_image_prompt(
-            request,
-            {
-                "model": form_data["model"],
-                "messages": messages,
-            },
-            user,
-        )
+    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
+        try:
+            res = await generate_image_prompt(
+                request,
+                {
+                    "model": form_data["model"],
+                    "messages": messages,
+                },
+                user,
+            )
 
 
-        response = res["choices"][0]["message"]["content"]
+            response = res["choices"][0]["message"]["content"]
 
 
-        try:
-            bracket_start = response.find("{")
-            bracket_end = response.rfind("}") + 1
+            try:
+                bracket_start = response.find("{")
+                bracket_end = response.rfind("}") + 1
 
 
-            if bracket_start == -1 or bracket_end == -1:
-                raise Exception("No JSON object found in the response")
+                if bracket_start == -1 or bracket_end == -1:
+                    raise Exception("No JSON object found in the response")
+
+                response = response[bracket_start:bracket_end]
+                response = json.loads(response)
+                prompt = response.get("prompt", [])
+            except Exception as e:
+                prompt = user_message
 
 
-            response = response[bracket_start:bracket_end]
-            response = json.loads(response)
-            prompt = response.get("prompt", [])
         except Exception as e:
         except Exception as e:
+            log.exception(e)
             prompt = user_message
             prompt = user_message
 
 
-    except Exception as e:
-        log.exception(e)
-        prompt = user_message
-
     system_message_content = ""
     system_message_content = ""
 
 
     try:
     try:

+ 11 - 2
src/lib/components/admin/Settings/Images.svelte

@@ -234,7 +234,7 @@
 				<div class=" mb-1 text-sm font-medium">{$i18n.t('Image Settings')}</div>
 				<div class=" mb-1 text-sm font-medium">{$i18n.t('Image Settings')}</div>
 
 
 				<div>
 				<div>
-					<div class=" py-0.5 flex w-full justify-between">
+					<div class=" py-1 flex w-full justify-between">
 						<div class=" self-center text-xs font-medium">
 						<div class=" self-center text-xs font-medium">
 							{$i18n.t('Image Generation (Experimental)')}
 							{$i18n.t('Image Generation (Experimental)')}
 						</div>
 						</div>
@@ -271,7 +271,16 @@
 					</div>
 					</div>
 				</div>
 				</div>
 
 
-				<div class=" py-0.5 flex w-full justify-between">
+				{#if config.enabled}
+					<div class=" py-1 flex w-full justify-between">
+						<div class=" self-center text-xs font-medium">{$i18n.t('Image Prompt Generation')}</div>
+						<div class="px-1">
+							<Switch bind:state={config.prompt_generation} />
+						</div>
+					</div>
+				{/if}
+
+				<div class=" py-1 flex w-full justify-between">
 					<div class=" self-center text-xs font-medium">{$i18n.t('Image Generation Engine')}</div>
 					<div class=" self-center text-xs font-medium">{$i18n.t('Image Generation Engine')}</div>
 					<div class="flex items-center relative">
 					<div class="flex items-center relative">
 						<select
 						<select