Browse files

feat: comfyui integration

Timothy J. Baek 1 year ago
parent
commit
98624a406f
4 changed files with 315 additions and 8 deletions
  1. backend/apps/images/main.py  +69 -4
  2. backend/apps/images/utils/comfyui.py  +228 -0
  3. src/lib/components/chat/Settings/Images.svelte  +1 -0
  4. src/lib/components/common/ImagePreview.svelte  +17 -4
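
For context while reading the diffs below, here is a rough sketch of how the new ComfyUI path could be exercised through the generations endpoint once the engine is switched to "comfyui". The mount prefix, port, and bearer token are assumptions for illustration; only the /generations route, its payload fields, and the cached-image response shape appear in this commit.

    import requests

    BASE_URL = "http://localhost:8080/images/api/v1"   # assumed mount point for the images app
    headers = {"Authorization": "Bearer <token>"}      # get_current_user expects an authenticated request

    body = {
        "prompt": "a watercolor fox in a snowy forest",
        "negative_prompt": "blurry, low quality",
        "n": 1,
    }

    r = requests.post(f"{BASE_URL}/generations", json=body, headers=headers)
    r.raise_for_status()
    # With ENGINE == "comfyui", the response is a list of cached image URLs, e.g.
    # [{"url": "/cache/image/generations/<uuid>.png"}]
    print(r.json())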

+ 69 - 4
backend/apps/images/main.py

@@ -18,6 +18,8 @@ from utils.utils import (
     get_current_user,
     get_admin_user,
 )
+
+from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
 from utils.misc import calculate_sha256
 from typing import Optional
 from pydantic import BaseModel
@@ -105,7 +107,12 @@ async def update_engine_url(
         app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
     else:
         url = form_data.COMFYUI_BASE_URL.strip("/")
-        app.state.COMFYUI_BASE_URL = url
+
+        try:
+            r = requests.head(url)
+            app.state.COMFYUI_BASE_URL = url
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 
     return {
         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
@@ -232,6 +239,8 @@ async def get_default_model(user=Depends(get_admin_user)):
     try:
         if app.state.ENGINE == "openai":
             return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
+        elif app.state.ENGINE == "comfyui":
+            return {"model": app.state.MODEL if app.state.MODEL else ""}
         else:
             r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
             options = r.json()
@@ -246,10 +255,12 @@ class UpdateModelForm(BaseModel):
 
 
 def set_model_handler(model: str):
-
     if app.state.ENGINE == "openai":
         app.state.MODEL = model
         return app.state.MODEL
+    elif app.state.ENGINE == "comfyui":
+        app.state.MODEL = model
+        return app.state.MODEL
     else:
         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
         options = r.json()
@@ -297,12 +308,31 @@ def save_b64_image(b64_str):
         return None
 
 
+def save_url_image(url):
+    image_id = str(uuid.uuid4())
+    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
+
+    try:
+        r = requests.get(url)
+        r.raise_for_status()
+
+        with open(file_path, "wb") as image_file:
+            image_file.write(r.content)
+
+        return image_id
+    except Exception as e:
+        print(f"Error saving image: {e}")
+        return None
+
+
 @app.post("/generations")
 def generate_image(
     form_data: GenerateImageForm,
     user=Depends(get_current_user),
 ):
 
+    width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+
     r = None
     try:
         if app.state.ENGINE == "openai":
@@ -340,12 +370,47 @@ def generate_image(
 
             return images
 
+        elif app.state.ENGINE == "comfyui":
+
+            data = {
+                "prompt": form_data.prompt,
+                "width": width,
+                "height": height,
+                "n": form_data.n,
+            }
+
+            if app.state.IMAGE_STEPS is not None:
+                data["steps"] = app.state.IMAGE_STEPS
+
+            if form_data.negative_prompt is not None:
+                data["negative_prompt"] = form_data.negative_prompt
+
+            data = ImageGenerationPayload(**data)
+
+            res = comfyui_generate_image(
+                app.state.MODEL,
+                data,
+                user.id,
+                app.state.COMFYUI_BASE_URL,
+            )
+            print(res)
+
+            images = []
+
+            for image in res["data"]:
+                image_id = save_url_image(image["url"])
+                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
+
+                with open(file_body_path, "w") as f:
+                    json.dump(data.model_dump(exclude_none=True), f)
+
+            print(images)
+            return images
         else:
             if form_data.model:
                 set_model_handler(form_data.model)
 
-            width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
-
             data = {
                 "prompt": form_data.prompt,
                 "batch_size": form_data.n,

+ 228 - 0
backend/apps/images/utils/comfyui.py

@@ -0,0 +1,228 @@
+import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
+import uuid
+import json
+import urllib.request
+import urllib.parse
+import random
+
+from pydantic import BaseModel
+
+from typing import Optional
+
+COMFYUI_DEFAULT_PROMPT = """
+{
+  "3": {
+    "inputs": {
+      "seed": 0,
+      "steps": 20,
+      "cfg": 8,
+      "sampler_name": "euler",
+      "scheduler": "normal",
+      "denoise": 1,
+      "model": [
+        "4",
+        0
+      ],
+      "positive": [
+        "6",
+        0
+      ],
+      "negative": [
+        "7",
+        0
+      ],
+      "latent_image": [
+        "5",
+        0
+      ]
+    },
+    "class_type": "KSampler",
+    "_meta": {
+      "title": "KSampler"
+    }
+  },
+  "4": {
+    "inputs": {
+      "ckpt_name": "model.safetensors"
+    },
+    "class_type": "CheckpointLoaderSimple",
+    "_meta": {
+      "title": "Load Checkpoint"
+    }
+  },
+  "5": {
+    "inputs": {
+      "width": 512,
+      "height": 512,
+      "batch_size": 1
+    },
+    "class_type": "EmptyLatentImage",
+    "_meta": {
+      "title": "Empty Latent Image"
+    }
+  },
+  "6": {
+    "inputs": {
+      "text": "Prompt",
+      "clip": [
+        "4",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "7": {
+    "inputs": {
+      "text": "Negative Prompt",
+      "clip": [
+        "4",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "8": {
+    "inputs": {
+      "samples": [
+        "3",
+        0
+      ],
+      "vae": [
+        "4",
+        2
+      ]
+    },
+    "class_type": "VAEDecode",
+    "_meta": {
+      "title": "VAE Decode"
+    }
+  },
+  "9": {
+    "inputs": {
+      "filename_prefix": "ComfyUI",
+      "images": [
+        "8",
+        0
+      ]
+    },
+    "class_type": "SaveImage",
+    "_meta": {
+      "title": "Save Image"
+    }
+  }
+}
+"""
+
+
+def queue_prompt(prompt, client_id, base_url):
+    print("queue_prompt")
+    p = {"prompt": prompt, "client_id": client_id}
+    data = json.dumps(p).encode("utf-8")
+    req = urllib.request.Request(f"{base_url}/prompt", data=data)
+    return json.loads(urllib.request.urlopen(req).read())
+
+
+def get_image(filename, subfolder, folder_type, base_url):
+    print("get_image")
+    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
+    url_values = urllib.parse.urlencode(data)
+    with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
+        return response.read()
+
+
+def get_image_url(filename, subfolder, folder_type, base_url):
+    print("get_image")
+    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
+    url_values = urllib.parse.urlencode(data)
+    return f"{base_url}/view?{url_values}"
+
+
+def get_history(prompt_id, base_url):
+    print("get_history")
+    with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
+        return json.loads(response.read())
+
+
+def get_images(ws, prompt, client_id, base_url):
+    prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
+    output_images = []
+    while True:
+        out = ws.recv()
+        if isinstance(out, str):
+            message = json.loads(out)
+            if message["type"] == "executing":
+                data = message["data"]
+                if data["node"] is None and data["prompt_id"] == prompt_id:
+                    break  # Execution is done
+        else:
+            continue  # previews are binary data
+
+    history = get_history(prompt_id, base_url)[prompt_id]
+    # collect every image produced by the graph's output nodes
+    for node_id in history["outputs"]:
+        node_output = history["outputs"][node_id]
+        if "images" in node_output:
+            for image in node_output["images"]:
+                url = get_image_url(
+                    image["filename"], image["subfolder"], image["type"], base_url
+                )
+                output_images.append({"url": url})
+    return {"data": output_images}
+
+
+class ImageGenerationPayload(BaseModel):
+    prompt: str
+    negative_prompt: Optional[str] = ""
+    steps: Optional[int] = None
+    seed: Optional[int] = None
+    width: int
+    height: int
+    n: int = 1
+
+
+def comfyui_generate_image(
+    model: str, payload: ImageGenerationPayload, client_id, base_url
+):
+    host = base_url.replace("http://", "").replace("https://", "")
+
+    comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
+
+    comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
+    comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
+    comfyui_prompt["5"]["inputs"]["width"] = payload.width
+    comfyui_prompt["5"]["inputs"]["height"] = payload.height
+
+    # set the text prompt for our positive CLIPTextEncode
+    comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
+    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
+
+    if payload.steps:
+        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
+
+    comfyui_prompt["3"]["inputs"]["seed"] = (
+        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
+    )
+
+    try:
+        ws = websocket.WebSocket()
+        ws.connect(f"ws://{host}/ws?clientId={client_id}")
+        print("WebSocket connection established.")
+    except Exception as e:
+        print(f"Failed to connect to WebSocket server: {e}")
+        return None
+
+    try:
+        images = get_images(ws, comfyui_prompt, client_id, base_url)
+    except Exception as e:
+        print(f"Error while receiving images: {e}")
+        images = None
+
+    ws.close()
+
+    return images
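
For reference, a minimal sketch of driving the new helper directly, assuming a ComfyUI server reachable at http://localhost:8188 and a checkpoint file named "model.safetensors" in its checkpoints folder (both assumptions, not part of this commit):

    from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image

    payload = ImageGenerationPayload(
        prompt="a lighthouse at dusk, oil painting",
        negative_prompt="text, watermark",
        width=512,
        height=512,
        n=1,
        steps=25,
    )

    res = comfyui_generate_image(
        "model.safetensors",        # injected into node "4" as ckpt_name
        payload,
        "example-client-id",        # reused as the ComfyUI websocket clientId
        "http://localhost:8188",
    )

    # On success: {"data": [{"url": "http://localhost:8188/view?filename=..."}]}
    # On connection or execution failure the helper prints the error and returns None.
    print(res)

The helper blocks on the websocket until the queued prompt finishes executing, then reads the output image URLs from the /history endpoint.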

+ 1 - 0
src/lib/components/chat/Settings/Images.svelte

@@ -323,6 +323,7 @@
 							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							bind:value={selectedModel}
 							placeholder={$i18n.t('Select a model')}
+							required
 						>
 							{#if !selectedModel}
 								<option value="" disabled selected>{$i18n.t('Select a model')}</option>

+ 17 - 4
src/lib/components/common/ImagePreview.svelte

@@ -2,6 +2,22 @@
 	export let show = false;
 	export let src = '';
 	export let alt = '';
+
+	const downloadImage = (url, filename) => {
+		fetch(url)
+			.then((response) => response.blob())
+			.then((blob) => {
+				const objectUrl = window.URL.createObjectURL(blob);
+				const link = document.createElement('a');
+				link.href = objectUrl;
+				link.download = filename;
+				document.body.appendChild(link);
+				link.click();
+				document.body.removeChild(link);
+				window.URL.revokeObjectURL(objectUrl);
+			})
+			.catch((error) => console.error('Error downloading image:', error));
+	};
 </script>
 
 {#if show}
@@ -35,10 +51,7 @@
 				<button
 					class=" p-5"
 					on:click={() => {
-						const a = document.createElement('a');
-						a.href = src;
-						a.download = 'Image.png';
-						a.click();
+						downloadImage(src, 'Image.png');
 					}}
 				>
 					<svg