Explorar o código

Merge pull request #1275 from open-webui/comfyui

feat: comfyui support
Timothy Jaeryang Baek hai 1 ano
pai
achega
adf9ccb5eb

+ 102 - 12
backend/apps/images/main.py

@@ -18,6 +18,8 @@ from utils.utils import (
     get_current_user,
     get_current_user,
     get_admin_user,
     get_admin_user,
 )
 )
+
+from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
 from utils.misc import calculate_sha256
 from utils.misc import calculate_sha256
 from typing import Optional
 from typing import Optional
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -26,7 +28,7 @@ import uuid
 import base64
 import base64
 import json
 import json
 
 
-from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
+from config import CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL
 
 
 
 
 IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
 IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
@@ -49,6 +51,8 @@ app.state.MODEL = ""
 
 
 
 
 app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
+app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+
 
 
 app.state.IMAGE_SIZE = "512x512"
 app.state.IMAGE_SIZE = "512x512"
 app.state.IMAGE_STEPS = 50
 app.state.IMAGE_STEPS = 50
@@ -71,32 +75,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user
     return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
     return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
 
 
 
 
-class UrlUpdateForm(BaseModel):
-    url: str
+class EngineUrlUpdateForm(BaseModel):
+    AUTOMATIC1111_BASE_URL: Optional[str] = None
+    COMFYUI_BASE_URL: Optional[str] = None
 
 
 
 
 @app.get("/url")
 @app.get("/url")
-async def get_automatic1111_url(user=Depends(get_admin_user)):
-    return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
+async def get_engine_url(user=Depends(get_admin_user)):
+    return {
+        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
+        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
+    }
 
 
 
 
 @app.post("/url/update")
 @app.post("/url/update")
-async def update_automatic1111_url(
-    form_data: UrlUpdateForm, user=Depends(get_admin_user)
+async def update_engine_url(
+    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
 ):
 ):
 
 
-    if form_data.url == "":
+    if form_data.AUTOMATIC1111_BASE_URL == None:
         app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
         app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
     else:
     else:
-        url = form_data.url.strip("/")
+        url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
         try:
         try:
             r = requests.head(url)
             r = requests.head(url)
             app.state.AUTOMATIC1111_BASE_URL = url
             app.state.AUTOMATIC1111_BASE_URL = url
         except Exception as e:
         except Exception as e:
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 
 
+    if form_data.COMFYUI_BASE_URL == None:
+        app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+    else:
+        url = form_data.COMFYUI_BASE_URL.strip("/")
+
+        try:
+            r = requests.head(url)
+            app.state.COMFYUI_BASE_URL = url
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
+
     return {
     return {
         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
+        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
         "status": True,
         "status": True,
     }
     }
 
 
@@ -186,6 +206,18 @@ def get_models(user=Depends(get_current_user)):
                 {"id": "dall-e-2", "name": "DALL·E 2"},
                 {"id": "dall-e-2", "name": "DALL·E 2"},
                 {"id": "dall-e-3", "name": "DALL·E 3"},
                 {"id": "dall-e-3", "name": "DALL·E 3"},
             ]
             ]
+        elif app.state.ENGINE == "comfyui":
+
+            r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
+            info = r.json()
+
+            return list(
+                map(
+                    lambda model: {"id": model, "name": model},
+                    info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
+                )
+            )
+
         else:
         else:
             r = requests.get(
             r = requests.get(
                 url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
                 url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
@@ -207,6 +239,8 @@ async def get_default_model(user=Depends(get_admin_user)):
     try:
     try:
         if app.state.ENGINE == "openai":
         if app.state.ENGINE == "openai":
             return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
             return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
+        elif app.state.ENGINE == "comfyui":
+            return {"model": app.state.MODEL if app.state.MODEL else ""}
         else:
         else:
             r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
             r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
             options = r.json()
             options = r.json()
@@ -221,10 +255,12 @@ class UpdateModelForm(BaseModel):
 
 
 
 
 def set_model_handler(model: str):
 def set_model_handler(model: str):
-
     if app.state.ENGINE == "openai":
     if app.state.ENGINE == "openai":
         app.state.MODEL = model
         app.state.MODEL = model
         return app.state.MODEL
         return app.state.MODEL
+    if app.state.ENGINE == "comfyui":
+        app.state.MODEL = model
+        return app.state.MODEL
     else:
     else:
         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
         options = r.json()
         options = r.json()
@@ -272,12 +308,31 @@ def save_b64_image(b64_str):
         return None
         return None
 
 
 
 
+def save_url_image(url):
+    image_id = str(uuid.uuid4())
+    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
+
+    try:
+        r = requests.get(url)
+        r.raise_for_status()
+
+        with open(file_path, "wb") as image_file:
+            image_file.write(r.content)
+
+        return image_id
+    except Exception as e:
+        print(f"Error saving image: {e}")
+        return None
+
+
 @app.post("/generations")
 @app.post("/generations")
 def generate_image(
 def generate_image(
     form_data: GenerateImageForm,
     form_data: GenerateImageForm,
     user=Depends(get_current_user),
     user=Depends(get_current_user),
 ):
 ):
 
 
+    width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+
     r = None
     r = None
     try:
     try:
         if app.state.ENGINE == "openai":
         if app.state.ENGINE == "openai":
@@ -315,12 +370,47 @@ def generate_image(
 
 
             return images
             return images
 
 
+        elif app.state.ENGINE == "comfyui":
+
+            data = {
+                "prompt": form_data.prompt,
+                "width": width,
+                "height": height,
+                "n": form_data.n,
+            }
+
+            if app.state.IMAGE_STEPS != None:
+                data["steps"] = app.state.IMAGE_STEPS
+
+            if form_data.negative_prompt != None:
+                data["negative_prompt"] = form_data.negative_prompt
+
+            data = ImageGenerationPayload(**data)
+
+            res = comfyui_generate_image(
+                app.state.MODEL,
+                data,
+                user.id,
+                app.state.COMFYUI_BASE_URL,
+            )
+            print(res)
+
+            images = []
+
+            for image in res["data"]:
+                image_id = save_url_image(image["url"])
+                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
+
+                with open(file_body_path, "w") as f:
+                    json.dump(data.model_dump(exclude_none=True), f)
+
+            print(images)
+            return images
         else:
         else:
             if form_data.model:
             if form_data.model:
                 set_model_handler(form_data.model)
                 set_model_handler(form_data.model)
 
 
-            width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
-
             data = {
             data = {
                 "prompt": form_data.prompt,
                 "prompt": form_data.prompt,
                 "batch_size": form_data.n,
                 "batch_size": form_data.n,

+ 228 - 0
backend/apps/images/utils/comfyui.py

@@ -0,0 +1,228 @@
+import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
+import uuid
+import json
+import urllib.request
+import urllib.parse
+import random
+
+from pydantic import BaseModel
+
+from typing import Optional
+
+COMFYUI_DEFAULT_PROMPT = """
+{
+  "3": {
+    "inputs": {
+      "seed": 0,
+      "steps": 20,
+      "cfg": 8,
+      "sampler_name": "euler",
+      "scheduler": "normal",
+      "denoise": 1,
+      "model": [
+        "4",
+        0
+      ],
+      "positive": [
+        "6",
+        0
+      ],
+      "negative": [
+        "7",
+        0
+      ],
+      "latent_image": [
+        "5",
+        0
+      ]
+    },
+    "class_type": "KSampler",
+    "_meta": {
+      "title": "KSampler"
+    }
+  },
+  "4": {
+    "inputs": {
+      "ckpt_name": "model.safetensors"
+    },
+    "class_type": "CheckpointLoaderSimple",
+    "_meta": {
+      "title": "Load Checkpoint"
+    }
+  },
+  "5": {
+    "inputs": {
+      "width": 512,
+      "height": 512,
+      "batch_size": 1
+    },
+    "class_type": "EmptyLatentImage",
+    "_meta": {
+      "title": "Empty Latent Image"
+    }
+  },
+  "6": {
+    "inputs": {
+      "text": "Prompt",
+      "clip": [
+        "4",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "7": {
+    "inputs": {
+      "text": "Negative Prompt",
+      "clip": [
+        "4",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "8": {
+    "inputs": {
+      "samples": [
+        "3",
+        0
+      ],
+      "vae": [
+        "4",
+        2
+      ]
+    },
+    "class_type": "VAEDecode",
+    "_meta": {
+      "title": "VAE Decode"
+    }
+  },
+  "9": {
+    "inputs": {
+      "filename_prefix": "ComfyUI",
+      "images": [
+        "8",
+        0
+      ]
+    },
+    "class_type": "SaveImage",
+    "_meta": {
+      "title": "Save Image"
+    }
+  }
+}
+"""
+
+
+def queue_prompt(prompt, client_id, base_url):
+    print("queue_prompt")
+    p = {"prompt": prompt, "client_id": client_id}
+    data = json.dumps(p).encode("utf-8")
+    req = urllib.request.Request(f"{base_url}/prompt", data=data)
+    return json.loads(urllib.request.urlopen(req).read())
+
+
+def get_image(filename, subfolder, folder_type, base_url):
+    print("get_image")
+    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
+    url_values = urllib.parse.urlencode(data)
+    with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
+        return response.read()
+
+
+def get_image_url(filename, subfolder, folder_type, base_url):
+    print("get_image")
+    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
+    url_values = urllib.parse.urlencode(data)
+    return f"{base_url}/view?{url_values}"
+
+
+def get_history(prompt_id, base_url):
+    print("get_history")
+    with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
+        return json.loads(response.read())
+
+
+def get_images(ws, prompt, client_id, base_url):
+    prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
+    output_images = []
+    while True:
+        out = ws.recv()
+        if isinstance(out, str):
+            message = json.loads(out)
+            if message["type"] == "executing":
+                data = message["data"]
+                if data["node"] is None and data["prompt_id"] == prompt_id:
+                    break  # Execution is done
+        else:
+            continue  # previews are binary data
+
+    history = get_history(prompt_id, base_url)[prompt_id]
+    for o in history["outputs"]:
+        for node_id in history["outputs"]:
+            node_output = history["outputs"][node_id]
+            if "images" in node_output:
+                for image in node_output["images"]:
+                    url = get_image_url(
+                        image["filename"], image["subfolder"], image["type"], base_url
+                    )
+                    output_images.append({"url": url})
+    return {"data": output_images}
+
+
+class ImageGenerationPayload(BaseModel):
+    prompt: str
+    negative_prompt: Optional[str] = ""
+    steps: Optional[int] = None
+    seed: Optional[int] = None
+    width: int
+    height: int
+    n: int = 1
+
+
+def comfyui_generate_image(
+    model: str, payload: ImageGenerationPayload, client_id, base_url
+):
+    host = base_url.replace("http://", "").replace("https://", "")
+
+    comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
+
+    comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
+    comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
+    comfyui_prompt["5"]["inputs"]["width"] = payload.width
+    comfyui_prompt["5"]["inputs"]["height"] = payload.height
+
+    # set the text prompt for our positive CLIPTextEncode
+    comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
+    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
+
+    if payload.steps:
+        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
+
+    comfyui_prompt["3"]["inputs"]["seed"] = (
+        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
+    )
+
+    try:
+        ws = websocket.WebSocket()
+        ws.connect(f"ws://{host}/ws?clientId={client_id}")
+        print("WebSocket connection established.")
+    except Exception as e:
+        print(f"Failed to connect to WebSocket server: {e}")
+        return None
+
+    try:
+        images = get_images(ws, comfyui_prompt, client_id, base_url)
+    except Exception as e:
+        print(f"Error while receiving images: {e}")
+        images = None
+
+    ws.close()
+
+    return images

+ 1 - 0
backend/config.py

@@ -376,3 +376,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models"
 ####################################
 ####################################
 
 
 AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
 AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
+COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")

+ 5 - 5
src/lib/apis/images/index.ts

@@ -139,7 +139,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
 	return res.OPENAI_API_KEY;
 	return res.OPENAI_API_KEY;
 };
 };
 
 
-export const getAUTOMATIC1111Url = async (token: string = '') => {
+export const getImageGenerationEngineUrls = async (token: string = '') => {
 	let error = null;
 	let error = null;
 
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/url`, {
 	const res = await fetch(`${IMAGES_API_BASE_URL}/url`, {
@@ -168,10 +168,10 @@ export const getAUTOMATIC1111Url = async (token: string = '') => {
 		throw error;
 		throw error;
 	}
 	}
 
 
-	return res.AUTOMATIC1111_BASE_URL;
+	return res;
 };
 };
 
 
-export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => {
+export const updateImageGenerationEngineUrls = async (token: string = '', urls: object = {}) => {
 	let error = null;
 	let error = null;
 
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, {
 	const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, {
@@ -182,7 +182,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) =>
 			...(token && { authorization: `Bearer ${token}` })
 			...(token && { authorization: `Bearer ${token}` })
 		},
 		},
 		body: JSON.stringify({
 		body: JSON.stringify({
-			url: url
+			...urls
 		})
 		})
 	})
 	})
 		.then(async (res) => {
 		.then(async (res) => {
@@ -203,7 +203,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) =>
 		throw error;
 		throw error;
 	}
 	}
 
 
-	return res.AUTOMATIC1111_BASE_URL;
+	return res;
 };
 };
 
 
 export const getImageSize = async (token: string = '') => {
 export const getImageSize = async (token: string = '') => {

+ 81 - 18
src/lib/components/chat/Settings/Images.svelte

@@ -4,14 +4,14 @@
 	import { createEventDispatcher, onMount, getContext } from 'svelte';
 	import { createEventDispatcher, onMount, getContext } from 'svelte';
 	import { config, user } from '$lib/stores';
 	import { config, user } from '$lib/stores';
 	import {
 	import {
-		getAUTOMATIC1111Url,
 		getImageGenerationModels,
 		getImageGenerationModels,
 		getDefaultImageGenerationModel,
 		getDefaultImageGenerationModel,
 		updateDefaultImageGenerationModel,
 		updateDefaultImageGenerationModel,
 		getImageSize,
 		getImageSize,
 		getImageGenerationConfig,
 		getImageGenerationConfig,
 		updateImageGenerationConfig,
 		updateImageGenerationConfig,
-		updateAUTOMATIC1111Url,
+		getImageGenerationEngineUrls,
+		updateImageGenerationEngineUrls,
 		updateImageSize,
 		updateImageSize,
 		getImageSteps,
 		getImageSteps,
 		updateImageSteps,
 		updateImageSteps,
@@ -31,6 +31,8 @@
 	let enableImageGeneration = false;
 	let enableImageGeneration = false;
 
 
 	let AUTOMATIC1111_BASE_URL = '';
 	let AUTOMATIC1111_BASE_URL = '';
+	let COMFYUI_BASE_URL = '';
+
 	let OPENAI_API_KEY = '';
 	let OPENAI_API_KEY = '';
 
 
 	let selectedModel = '';
 	let selectedModel = '';
@@ -49,24 +51,47 @@
 		});
 		});
 	};
 	};
 
 
-	const updateAUTOMATIC1111UrlHandler = async () => {
-		const res = await updateAUTOMATIC1111Url(localStorage.token, AUTOMATIC1111_BASE_URL).catch(
-			(error) => {
+	const updateUrlHandler = async () => {
+		if (imageGenerationEngine === 'comfyui') {
+			const res = await updateImageGenerationEngineUrls(localStorage.token, {
+				COMFYUI_BASE_URL: COMFYUI_BASE_URL
+			}).catch((error) => {
 				toast.error(error);
 				toast.error(error);
+
+				console.log(error);
 				return null;
 				return null;
-			}
-		);
+			});
 
 
-		if (res) {
-			AUTOMATIC1111_BASE_URL = res;
+			if (res) {
+				COMFYUI_BASE_URL = res.COMFYUI_BASE_URL;
 
 
-			await getModels();
+				await getModels();
 
 
-			if (models) {
-				toast.success($i18n.t('Server connection verified'));
+				if (models) {
+					toast.success($i18n.t('Server connection verified'));
+				}
+			} else {
+				({ COMFYUI_BASE_URL } = await getImageGenerationEngineUrls(localStorage.token));
 			}
 			}
 		} else {
 		} else {
-			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
+			const res = await updateImageGenerationEngineUrls(localStorage.token, {
+				AUTOMATIC1111_BASE_URL: AUTOMATIC1111_BASE_URL
+			}).catch((error) => {
+				toast.error(error);
+				return null;
+			});
+
+			if (res) {
+				AUTOMATIC1111_BASE_URL = res.AUTOMATIC1111_BASE_URL;
+
+				await getModels();
+
+				if (models) {
+					toast.success($i18n.t('Server connection verified'));
+				}
+			} else {
+				({ AUTOMATIC1111_BASE_URL } = await getImageGenerationEngineUrls(localStorage.token));
+			}
 		}
 		}
 	};
 	};
 	const updateImageGeneration = async () => {
 	const updateImageGeneration = async () => {
@@ -101,7 +126,11 @@
 				imageGenerationEngine = res.engine;
 				imageGenerationEngine = res.engine;
 				enableImageGeneration = res.enabled;
 				enableImageGeneration = res.enabled;
 			}
 			}
-			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
+			const URLS = await getImageGenerationEngineUrls(localStorage.token);
+
+			AUTOMATIC1111_BASE_URL = URLS.AUTOMATIC1111_BASE_URL;
+			COMFYUI_BASE_URL = URLS.COMFYUI_BASE_URL;
+
 			OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
 			OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
 
 
 			imageSize = await getImageSize(localStorage.token);
 			imageSize = await getImageSize(localStorage.token);
@@ -154,6 +183,7 @@
 						}}
 						}}
 					>
 					>
 						<option value="">{$i18n.t('Default (Automatic1111)')}</option>
 						<option value="">{$i18n.t('Default (Automatic1111)')}</option>
+						<option value="comfyui">{$i18n.t('ComfyUI')}</option>
 						<option value="openai">{$i18n.t('Open AI (Dall-E)')}</option>
 						<option value="openai">{$i18n.t('Open AI (Dall-E)')}</option>
 					</select>
 					</select>
 				</div>
 				</div>
@@ -171,6 +201,9 @@
 							if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') {
 							if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') {
 								toast.error($i18n.t('AUTOMATIC1111 Base URL is required.'));
 								toast.error($i18n.t('AUTOMATIC1111 Base URL is required.'));
 								enableImageGeneration = false;
 								enableImageGeneration = false;
+							} else if (imageGenerationEngine === 'comfyui' && COMFYUI_BASE_URL === '') {
+								toast.error($i18n.t('ComfyUI Base URL is required.'));
+								enableImageGeneration = false;
 							} else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') {
 							} else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') {
 								toast.error($i18n.t('OpenAI API Key is required.'));
 								toast.error($i18n.t('OpenAI API Key is required.'));
 								enableImageGeneration = false;
 								enableImageGeneration = false;
@@ -204,12 +237,10 @@
 					/>
 					/>
 				</div>
 				</div>
 				<button
 				<button
-					class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded-lg transition"
+					class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
 					type="button"
 					type="button"
 					on:click={() => {
 					on:click={() => {
-						// updateOllamaAPIUrlHandler();
-
-						updateAUTOMATIC1111UrlHandler();
+						updateUrlHandler();
 					}}
 					}}
 				>
 				>
 					<svg
 					<svg
@@ -237,6 +268,37 @@
 					{$i18n.t('(e.g. `sh webui.sh --api`)')}
 					{$i18n.t('(e.g. `sh webui.sh --api`)')}
 				</a>
 				</a>
 			</div>
 			</div>
+		{:else if imageGenerationEngine === 'comfyui'}
+			<div class=" mb-2.5 text-sm font-medium">{$i18n.t('ComfyUI Base URL')}</div>
+			<div class="flex w-full">
+				<div class="flex-1 mr-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder={$i18n.t('Enter URL (e.g. http://127.0.0.1:7860/)')}
+						bind:value={COMFYUI_BASE_URL}
+					/>
+				</div>
+				<button
+					class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+					type="button"
+					on:click={() => {
+						updateUrlHandler();
+					}}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						viewBox="0 0 20 20"
+						fill="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							fill-rule="evenodd"
+							d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
+							clip-rule="evenodd"
+						/>
+					</svg>
+				</button>
+			</div>
 		{:else if imageGenerationEngine === 'openai'}
 		{:else if imageGenerationEngine === 'openai'}
 			<div class=" mb-2.5 text-sm font-medium">{$i18n.t('OpenAI API Key')}</div>
 			<div class=" mb-2.5 text-sm font-medium">{$i18n.t('OpenAI API Key')}</div>
 			<div class="flex w-full">
 			<div class="flex w-full">
@@ -261,6 +323,7 @@
 							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							bind:value={selectedModel}
 							bind:value={selectedModel}
 							placeholder={$i18n.t('Select a model')}
 							placeholder={$i18n.t('Select a model')}
+							required
 						>
 						>
 							{#if !selectedModel}
 							{#if !selectedModel}
 								<option value="" disabled selected>{$i18n.t('Select a model')}</option>
 								<option value="" disabled selected>{$i18n.t('Select a model')}</option>

+ 17 - 4
src/lib/components/common/ImagePreview.svelte

@@ -2,6 +2,22 @@
 	export let show = false;
 	export let show = false;
 	export let src = '';
 	export let src = '';
 	export let alt = '';
 	export let alt = '';
+
+	const downloadImage = (url, filename) => {
+		fetch(url)
+			.then((response) => response.blob())
+			.then((blob) => {
+				const objectUrl = window.URL.createObjectURL(blob);
+				const link = document.createElement('a');
+				link.href = objectUrl;
+				link.download = filename;
+				document.body.appendChild(link);
+				link.click();
+				document.body.removeChild(link);
+				window.URL.revokeObjectURL(objectUrl);
+			})
+			.catch((error) => console.error('Error downloading image:', error));
+	};
 </script>
 </script>
 
 
 {#if show}
 {#if show}
@@ -35,10 +51,7 @@
 				<button
 				<button
 					class=" p-5"
 					class=" p-5"
 					on:click={() => {
 					on:click={() => {
-						const a = document.createElement('a');
-						a.href = src;
-						a.download = 'Image.png';
-						a.click();
+						downloadImage(src, 'Image.png');
 					}}
 					}}
 				>
 				>
 					<svg
 					<svg