浏览代码

feat: custom `COMFYUI_WORKFLOW`

deprecates several comfyui env vars
Timothy J. Baek 8 月之前
父节点
当前提交
c5310e84db
共有 3 个文件被更改,包括 59 次插入281 次删除
  1. 6 40
      backend/apps/images/main.py
  2. 49 201
      backend/apps/images/utils/comfyui.py
  3. 4 40
      backend/config.py

+ 6 - 40
backend/apps/images/main.py

@@ -14,7 +14,7 @@ from utils.utils import (
     get_admin_user,
     get_admin_user,
 )
 )
 
 
-from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
+from apps.images.utils.comfyui import ComfyUIGenerateImageForm, comfyui_generate_image
 from typing import Optional
 from typing import Optional
 from pydantic import BaseModel
 from pydantic import BaseModel
 from pathlib import Path
 from pathlib import Path
@@ -32,20 +32,14 @@ from config import (
     AUTOMATIC1111_BASE_URL,
     AUTOMATIC1111_BASE_URL,
     AUTOMATIC1111_API_AUTH,
     AUTOMATIC1111_API_AUTH,
     COMFYUI_BASE_URL,
     COMFYUI_BASE_URL,
-    COMFYUI_CFG_SCALE,
-    COMFYUI_SAMPLER,
-    COMFYUI_SCHEDULER,
-    COMFYUI_SD3,
-    COMFYUI_FLUX,
-    COMFYUI_FLUX_WEIGHT_DTYPE,
-    COMFYUI_FLUX_FP8_CLIP,
+    COMFYUI_WORKFLOW,
     IMAGES_OPENAI_API_BASE_URL,
     IMAGES_OPENAI_API_BASE_URL,
     IMAGES_OPENAI_API_KEY,
     IMAGES_OPENAI_API_KEY,
     IMAGE_GENERATION_MODEL,
     IMAGE_GENERATION_MODEL,
     IMAGE_SIZE,
     IMAGE_SIZE,
     IMAGE_STEPS,
     IMAGE_STEPS,
-    AppConfig,
     CORS_ALLOW_ORIGIN,
     CORS_ALLOW_ORIGIN,
+    AppConfig,
 )
 )
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
@@ -76,16 +70,10 @@ app.state.config.MODEL = IMAGE_GENERATION_MODEL
 app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
 app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
 app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
 app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
 
 
 app.state.config.IMAGE_SIZE = IMAGE_SIZE
 app.state.config.IMAGE_SIZE = IMAGE_SIZE
 app.state.config.IMAGE_STEPS = IMAGE_STEPS
 app.state.config.IMAGE_STEPS = IMAGE_STEPS
-app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
-app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
-app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
-app.state.config.COMFYUI_SD3 = COMFYUI_SD3
-app.state.config.COMFYUI_FLUX = COMFYUI_FLUX
-app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE
-app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
 
 
 
 
 def get_automatic1111_api_auth():
 def get_automatic1111_api_auth():
@@ -488,32 +476,10 @@ async def image_generations(
             if form_data.negative_prompt is not None:
             if form_data.negative_prompt is not None:
                 data["negative_prompt"] = form_data.negative_prompt
                 data["negative_prompt"] = form_data.negative_prompt
 
 
-            if app.state.config.COMFYUI_CFG_SCALE:
-                data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
-
-            if app.state.config.COMFYUI_SAMPLER is not None:
-                data["sampler"] = app.state.config.COMFYUI_SAMPLER
-
-            if app.state.config.COMFYUI_SCHEDULER is not None:
-                data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
-
-            if app.state.config.COMFYUI_SD3 is not None:
-                data["sd3"] = app.state.config.COMFYUI_SD3
-
-            if app.state.config.COMFYUI_FLUX is not None:
-                data["flux"] = app.state.config.COMFYUI_FLUX
-
-            if app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE is not None:
-                data["flux_weight_dtype"] = app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE
-
-            if app.state.config.COMFYUI_FLUX_FP8_CLIP is not None:
-                data["flux_fp8_clip"] = app.state.config.COMFYUI_FLUX_FP8_CLIP
-
-            data = ImageGenerationPayload(**data)
-
+            form_data = ComfyUIGenerateImageForm(**data)
             res = await comfyui_generate_image(
             res = await comfyui_generate_image(
                 app.state.config.MODEL,
                 app.state.config.MODEL,
-                data,
+                form_data,
                 user.id,
                 user.id,
                 app.state.config.COMFYUI_BASE_URL,
                 app.state.config.COMFYUI_BASE_URL,
             )
             )

+ 49 - 201
backend/apps/images/utils/comfyui.py

@@ -15,7 +15,7 @@ from pydantic import BaseModel
 
 
 from typing import Optional
 from typing import Optional
 
 
-COMFYUI_DEFAULT_PROMPT = """
+COMFYUI_DEFAULT_WORKFLOW = """
 {
 {
   "3": {
   "3": {
     "inputs": {
     "inputs": {
@@ -82,7 +82,7 @@ COMFYUI_DEFAULT_PROMPT = """
   },
   },
   "7": {
   "7": {
     "inputs": {
     "inputs": {
-      "text": "Negative Prompt",
+      "text": "",
       "clip": [
       "clip": [
         "4",
         "4",
         1
         1
@@ -125,135 +125,6 @@ COMFYUI_DEFAULT_PROMPT = """
 }
 }
 """
 """
 
 
-FLUX_DEFAULT_PROMPT = """
-{
-    "5": {
-        "inputs": {
-            "width": 1024,
-            "height": 1024,
-            "batch_size": 1
-        },
-        "class_type": "EmptyLatentImage"
-    },
-    "6": {
-        "inputs": {
-            "text": "Input Text Here",
-            "clip": [
-                "11",
-                0
-            ]
-        },
-        "class_type": "CLIPTextEncode"
-    },
-    "8": {
-        "inputs": {
-            "samples": [
-                "13",
-                0
-            ],
-            "vae": [
-                "10",
-                0
-            ]
-        },
-        "class_type": "VAEDecode"
-    },
-    "9": {
-        "inputs": {
-            "filename_prefix": "ComfyUI",
-            "images": [
-                "8",
-                0
-            ]
-        },
-        "class_type": "SaveImage"
-    },
-    "10": {
-        "inputs": {
-            "vae_name": "ae.safetensors"
-        },
-        "class_type": "VAELoader"
-    },
-    "11": {
-        "inputs": {
-            "clip_name1": "clip_l.safetensors",
-            "clip_name2": "t5xxl_fp16.safetensors",
-            "type": "flux"
-        },
-        "class_type": "DualCLIPLoader"
-    },
-    "12": {
-        "inputs": {
-            "unet_name": "flux1-dev.safetensors",
-            "weight_dtype": "default"
-        },
-        "class_type": "UNETLoader"
-    },
-    "13": {
-        "inputs": {
-            "noise": [
-                "25",
-                0
-            ],
-            "guider": [
-                "22",
-                0
-            ],
-            "sampler": [
-                "16",
-                0
-            ],
-            "sigmas": [
-                "17",
-                0
-            ],
-            "latent_image": [
-                "5",
-                0
-            ]
-        },
-        "class_type": "SamplerCustomAdvanced"
-    },
-    "16": {
-        "inputs": {
-            "sampler_name": "euler"
-        },
-        "class_type": "KSamplerSelect"
-    },
-    "17": {
-        "inputs": {
-            "scheduler": "simple",
-            "steps": 20,
-            "denoise": 1,
-            "model": [
-                "12",
-                0
-            ]
-        },
-        "class_type": "BasicScheduler"
-    },
-    "22": {
-        "inputs": {
-            "model": [
-                "12",
-                0
-            ],
-            "conditioning": [
-                "6",
-                0
-            ]
-        },
-        "class_type": "BasicGuider"
-    },
-    "25": {
-        "inputs": {
-            "noise_seed": 778937779713005
-        },
-        "class_type": "RandomNoise"
-    }
-}
-"""
-
 
 
 def queue_prompt(prompt, client_id, base_url):
 def queue_prompt(prompt, client_id, base_url):
     log.info("queue_prompt")
     log.info("queue_prompt")
@@ -311,82 +182,61 @@ def get_images(ws, prompt, client_id, base_url):
     return {"data": output_images}
     return {"data": output_images}
 
 
 
 
-class ImageGenerationPayload(BaseModel):
+class ComfyUINodeInput(BaseModel):
+    field: Optional[str] = None
+    node_id: str
+    key: Optional[str] = "text"
+    value: Optional[str] = None
+
+
+class ComfyUIWorkflow(BaseModel):
+    workflow: str
+    nodes: list[ComfyUINodeInput]
+
+
+class ComfyUIGenerateImageForm(BaseModel):
+    workflow: ComfyUIWorkflow
+
     prompt: str
     prompt: str
-    negative_prompt: Optional[str] = ""
-    steps: Optional[int] = None
-    seed: Optional[int] = None
+    negative_prompt: Optional[str] = None
     width: int
     width: int
     height: int
     height: int
     n: int = 1
     n: int = 1
-    cfg_scale: Optional[float] = None
-    sampler: Optional[str] = None
-    scheduler: Optional[str] = None
-    sd3: Optional[bool] = None
-    flux: Optional[bool] = None
-    flux_weight_dtype: Optional[str] = None
-    flux_fp8_clip: Optional[bool] = None
+
+    steps: Optional[int] = None
+    seed: Optional[int] = None
 
 
 
 
 async def comfyui_generate_image(
 async def comfyui_generate_image(
-    model: str, payload: ImageGenerationPayload, client_id, base_url
+    model: str, payload: ComfyUIGenerateImageForm, client_id, base_url
 ):
 ):
     ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
     ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
-
-    comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
-
-    if payload.cfg_scale:
-        comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale
-
-    if payload.sampler:
-        comfyui_prompt["3"]["inputs"]["sampler"] = payload.sampler
-
-    if payload.scheduler:
-        comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler
-
-    if payload.sd3:
-        comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"
-
-    if payload.steps:
-        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
-
-    comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
-    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
-    comfyui_prompt["3"]["inputs"]["seed"] = (
-        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
-    )
-
-    # as Flux uses a completely different workflow, we must treat it specially
-    if payload.flux:
-        comfyui_prompt = json.loads(FLUX_DEFAULT_PROMPT)
-        comfyui_prompt["12"]["inputs"]["unet_name"] = model
-        comfyui_prompt["25"]["inputs"]["noise_seed"] = (
-            payload.seed if payload.seed else random.randint(0, 18446744073709551614)
-        )
-
-        if payload.sampler:
-            comfyui_prompt["16"]["inputs"]["sampler_name"] = payload.sampler
-
-        if payload.steps:
-            comfyui_prompt["17"]["inputs"]["steps"] = payload.steps
-
-        if payload.scheduler:
-            comfyui_prompt["17"]["inputs"]["scheduler"] = payload.scheduler
-
-        if payload.flux_weight_dtype:
-            comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype
-
-        if payload.flux_fp8_clip:
-            comfyui_prompt["11"]["inputs"][
-                "clip_name2"
-            ] = "t5xxl_fp8_e4m3fn.safetensors"
-
-    comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
-    comfyui_prompt["5"]["inputs"]["width"] = payload.width
-    comfyui_prompt["5"]["inputs"]["height"] = payload.height
-
-    # set the text prompt for our positive CLIPTextEncode
-    comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
+    workflow = json.loads(payload.workflow.workflow)
+
+    for node in payload.workflow.nodes:
+        if node.field:
+            if node.field == "model":
+                workflow[node.node_id]["inputs"][node.key] = model
+            elif node.field == "prompt":
+                workflow[node.node_id]["inputs"]["text"] = payload.prompt
+            elif node.field == "negative_prompt":
+                workflow[node.node_id]["inputs"]["text"] = payload.negative_prompt
+            elif node.field == "width":
+                workflow[node.node_id]["inputs"]["width"] = payload.width
+            elif node.field == "height":
+                workflow[node.node_id]["inputs"]["height"] = payload.height
+            elif node.field == "n":
+                workflow[node.node_id]["inputs"]["batch_size"] = payload.n
+            elif node.field == "steps":
+                workflow[node.node_id]["inputs"]["steps"] = payload.steps
+            elif node.field == "seed":
+                workflow[node.node_id]["inputs"]["seed"] = (
+                    payload.seed
+                    if payload.seed
+                    else random.randint(0, 18446744073709551614)
+                )
+        else:
+            workflow[node.node_id]["inputs"][node.key] = node.value
 
 
     try:
     try:
         ws = websocket.WebSocket()
         ws = websocket.WebSocket()
@@ -397,9 +247,7 @@ async def comfyui_generate_image(
         return None
         return None
 
 
     try:
     try:
-        images = await asyncio.to_thread(
-            get_images, ws, comfyui_prompt, client_id, base_url
-        )
+        images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url)
     except Exception as e:
     except Exception as e:
         log.exception(f"Error while receiving images: {e}")
         log.exception(f"Error while receiving images: {e}")
         images = None
         images = None

+ 4 - 40
backend/config.py

@@ -1342,46 +1342,10 @@ COMFYUI_BASE_URL = PersistentConfig(
     os.getenv("COMFYUI_BASE_URL", ""),
     os.getenv("COMFYUI_BASE_URL", ""),
 )
 )
 
 
-COMFYUI_CFG_SCALE = PersistentConfig(
-    "COMFYUI_CFG_SCALE",
-    "image_generation.comfyui.cfg_scale",
-    os.getenv("COMFYUI_CFG_SCALE", ""),
-)
-
-COMFYUI_SAMPLER = PersistentConfig(
-    "COMFYUI_SAMPLER",
-    "image_generation.comfyui.sampler",
-    os.getenv("COMFYUI_SAMPLER", ""),
-)
-
-COMFYUI_SCHEDULER = PersistentConfig(
-    "COMFYUI_SCHEDULER",
-    "image_generation.comfyui.scheduler",
-    os.getenv("COMFYUI_SCHEDULER", ""),
-)
-
-COMFYUI_SD3 = PersistentConfig(
-    "COMFYUI_SD3",
-    "image_generation.comfyui.sd3",
-    os.environ.get("COMFYUI_SD3", "").lower() == "true",
-)
-
-COMFYUI_FLUX = PersistentConfig(
-    "COMFYUI_FLUX",
-    "image_generation.comfyui.flux",
-    os.environ.get("COMFYUI_FLUX", "").lower() == "true",
-)
-
-COMFYUI_FLUX_WEIGHT_DTYPE = PersistentConfig(
-    "COMFYUI_FLUX_WEIGHT_DTYPE",
-    "image_generation.comfyui.flux_weight_dtype",
-    os.getenv("COMFYUI_FLUX_WEIGHT_DTYPE", ""),
-)
-
-COMFYUI_FLUX_FP8_CLIP = PersistentConfig(
-    "COMFYUI_FLUX_FP8_CLIP",
-    "image_generation.comfyui.flux_fp8_clip",
-    os.environ.get("COMFYUI_FLUX_FP8_CLIP", "").lower() == "true",
+COMFYUI_WORKFLOW = PersistentConfig(
+    "COMFYUI_WORKFLOW",
+    "image_generation.comfyui.workflow",
+    os.getenv("COMFYUI_WORKFLOW", ""),
 )
 )
 
 
 IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
 IMAGES_OPENAI_API_BASE_URL = PersistentConfig(