ソースを参照

added support for the new Flux image gen model using ComfyUI

this commit adds three environment variables:

- COMFYUI_FLUX: determines whether Flux is used, the workflow is completely different so this is necessary.
- COMFYUI_FLUX_WEIGHT_DTYPE: sets the weight precision for Flux. you will probably want to set this to "fp8_e4m3fn" as the fp16 weights take up about 24GB of VRAM. optional, defaults to "default".
- COMFYUI_FLUX_FP8_CLIP: Flux requires two CLIP models downloaded, one of which is available in fp8 and fp16. set to true if you are using the fp8 CLIP weights.
John Karabudak 9 ヶ月 前
コミット
ad6e8edcd3
3 ファイル変更196 行追加9 行削除
  1. 15 1
      backend/apps/images/main.py
  2. 163 8
      backend/apps/images/utils/comfyui.py
  3. 18 0
      backend/config.py

+ 15 - 1
backend/apps/images/main.py

@@ -42,6 +42,9 @@ from config import (
     COMFYUI_SAMPLER,
     COMFYUI_SCHEDULER,
     COMFYUI_SD3,
+    COMFYUI_FLUX,
+    COMFYUI_FLUX_WEIGHT_DTYPE,
+    COMFYUI_FLUX_FP8_CLIP,
     IMAGES_OPENAI_API_BASE_URL,
     IMAGES_OPENAI_API_KEY,
     IMAGE_GENERATION_MODEL,
@@ -85,7 +88,9 @@ app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
 app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
 app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
 app.state.config.COMFYUI_SD3 = COMFYUI_SD3
-
+app.state.config.COMFYUI_FLUX = COMFYUI_FLUX
+app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE
+app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
 
 def get_automatic1111_api_auth():
     if app.state.config.AUTOMATIC1111_API_AUTH == None:
@@ -497,6 +502,15 @@ async def image_generations(
             if app.state.config.COMFYUI_SD3 is not None:
                 data["sd3"] = app.state.config.COMFYUI_SD3
 
+            if app.state.config.COMFYUI_FLUX is not None:
+                data["flux"] = app.state.config.COMFYUI_FLUX
+
+            if app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE is not None:
+                data["flux_weight_dtype"] = app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE
+
+            if app.state.config.COMFYUI_FLUX_FP8_CLIP is not None:
+                data["flux_fp8_clip"] = app.state.config.COMFYUI_FLUX_FP8_CLIP
+
             data = ImageGenerationPayload(**data)
 
             res = comfyui_generate_image(

+ 163 - 8
backend/apps/images/utils/comfyui.py

@@ -125,6 +125,135 @@ COMFYUI_DEFAULT_PROMPT = """
 }
 """
 
+FLUX_DEFAULT_PROMPT = """
+{
+    "5": {
+        "inputs": {
+            "width": 1024,
+            "height": 1024,
+            "batch_size": 1
+        },
+        "class_type": "EmptyLatentImage"
+    },
+    "6": {
+        "inputs": {
+            "text": "Input Text Here",
+            "clip": [
+                "11",
+                0
+            ]
+        },
+        "class_type": "CLIPTextEncode"
+    },
+    "8": {
+        "inputs": {
+            "samples": [
+                "13",
+                0
+            ],
+            "vae": [
+                "10",
+                0
+            ]
+        },
+        "class_type": "VAEDecode"
+    },
+    "9": {
+        "inputs": {
+            "filename_prefix": "ComfyUI",
+            "images": [
+                "8",
+                0
+            ]
+        },
+        "class_type": "SaveImage"
+    },
+    "10": {
+        "inputs": {
+            "vae_name": "ae.sft"
+        },
+        "class_type": "VAELoader"
+    },
+    "11": {
+        "inputs": {
+            "clip_name1": "clip_l.safetensors",
+            "clip_name2": "t5xxl_fp16.safetensors",
+            "type": "flux"
+        },
+        "class_type": "DualCLIPLoader"
+    },
+    "12": {
+        "inputs": {
+            "unet_name": "flux1-dev.sft",
+            "weight_dtype": "default"
+        },
+        "class_type": "UNETLoader"
+    },
+    "13": {
+        "inputs": {
+            "noise": [
+                "25",
+                0
+            ],
+            "guider": [
+                "22",
+                0
+            ],
+            "sampler": [
+                "16",
+                0
+            ],
+            "sigmas": [
+                "17",
+                0
+            ],
+            "latent_image": [
+                "5",
+                0
+            ]
+        },
+        "class_type": "SamplerCustomAdvanced"
+    },
+    "16": {
+        "inputs": {
+            "sampler_name": "euler"
+        },
+        "class_type": "KSamplerSelect"
+    },
+    "17": {
+        "inputs": {
+            "scheduler": "simple",
+            "steps": 20,
+            "denoise": 1,
+            "model": [
+                "12",
+                0
+            ]
+        },
+        "class_type": "BasicScheduler"
+    },
+    "22": {
+        "inputs": {
+            "model": [
+                "12",
+                0
+            ],
+            "conditioning": [
+                "6",
+                0
+            ]
+        },
+        "class_type": "BasicGuider"
+    },
+    "25": {
+        "inputs": {
+            "noise_seed": 778937779713005
+        },
+        "class_type": "RandomNoise"
+    }
+}
+"""
+
 
 def queue_prompt(prompt, client_id, base_url):
     log.info("queue_prompt")
@@ -194,6 +323,9 @@ class ImageGenerationPayload(BaseModel):
     sampler: Optional[str] = None
     scheduler: Optional[str] = None
     sd3: Optional[bool] = None
+    flux: Optional[bool] = None
+    flux_weight_dtype: Optional[str] = None
+    flux_fp8_clip: Optional[bool] = None
 
 
 def comfyui_generate_image(
@@ -215,21 +347,44 @@ def comfyui_generate_image(
     if payload.sd3:
         comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"
 
+    if payload.steps:
+        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
+
     comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
+    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
+    comfyui_prompt["3"]["inputs"]["seed"] = (
+        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
+    )
+
+    # as Flux uses a completely different workflow, we must treat it specially
+    if payload.flux:
+        comfyui_prompt = json.loads(FLUX_DEFAULT_PROMPT)
+        comfyui_prompt["12"]["inputs"]["unet_name"] = model
+        comfyui_prompt["25"]["inputs"]["noise_seed"] = (
+            payload.seed if payload.seed else random.randint(0, 18446744073709551614)
+        )
+
+        if payload.sampler:
+            comfyui_prompt["16"]["inputs"]["sampler_name"] = payload.sampler
+
+        if payload.steps:
+            comfyui_prompt["17"]["inputs"]["steps"] = payload.steps
+
+        if payload.scheduler:
+            comfyui_prompt["17"]["inputs"]["scheduler"] = payload.scheduler
+
+        if payload.flux_weight_dtype:
+            comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype
+        
+        if payload.flux_fp8_clip:
+            comfyui_prompt["11"]["inputs"]["clip_name2"] = "t5xxl_fp8_e4m3fn.safetensors"
+
     comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
     comfyui_prompt["5"]["inputs"]["width"] = payload.width
     comfyui_prompt["5"]["inputs"]["height"] = payload.height
 
     # set the text prompt for our positive CLIPTextEncode
     comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
-    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
-
-    if payload.steps:
-        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
-
-    comfyui_prompt["3"]["inputs"]["seed"] = (
-        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
-    )
 
     try:
         ws = websocket.WebSocket()

+ 18 - 0
backend/config.py

@@ -1302,6 +1302,24 @@ COMFYUI_SD3 = PersistentConfig(
     os.environ.get("COMFYUI_SD3", "").lower() == "true",
 )
 
+COMFYUI_FLUX = PersistentConfig(
+    "COMFYUI_FLUX",
+    "image_generation.comfyui.flux",
+    os.environ.get("COMFYUI_FLUX", "").lower() == "true",
+)
+
+COMFYUI_FLUX_WEIGHT_DTYPE = PersistentConfig(
+    "COMFYUI_FLUX_WEIGHT_DTYPE",
+    "image_generation.comfyui.flux_weight_dtype",
+    os.getenv("COMFYUI_FLUX_WEIGHT_DTYPE", ""),
+)
+
+COMFYUI_FLUX_FP8_CLIP = PersistentConfig(
+    "COMFYUI_FLUX_FP8_CLIP",
+    "image_generation.comfyui.flux_fp8_clip",
+    os.getenv("COMFYUI_FLUX_FP8_CLIP", ""),
+)
+
 IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
     "IMAGES_OPENAI_API_BASE_URL",
     "image_generation.openai.api_base_url",