# comfyui.py

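"""
Helper module for generating images through a ComfyUI server.

Workflows are described as ComfyUI "prompt" graphs (JSON node maps). A default
Stable Diffusion workflow and a Flux workflow are provided as templates; the
payload fields are patched into the relevant nodes, the graph is queued over
HTTP, and completion is detected over the server's WebSocket before the output
image URLs are collected.
"""
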
import asyncio
import json
import logging
import random
import urllib.parse
import urllib.request
import uuid
from typing import Optional

import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
from pydantic import BaseModel

from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])

COMFYUI_DEFAULT_PROMPT = """
{
    "3": {
        "inputs": {
            "seed": 0,
            "steps": 20,
            "cfg": 8,
            "sampler_name": "euler",
            "scheduler": "normal",
            "denoise": 1,
            "model": ["4", 0],
            "positive": ["6", 0],
            "negative": ["7", 0],
            "latent_image": ["5", 0]
        },
        "class_type": "KSampler",
        "_meta": {"title": "KSampler"}
    },
    "4": {
        "inputs": {
            "ckpt_name": "model.safetensors"
        },
        "class_type": "CheckpointLoaderSimple",
        "_meta": {"title": "Load Checkpoint"}
    },
    "5": {
        "inputs": {
            "width": 512,
            "height": 512,
            "batch_size": 1
        },
        "class_type": "EmptyLatentImage",
        "_meta": {"title": "Empty Latent Image"}
    },
    "6": {
        "inputs": {
            "text": "Prompt",
            "clip": ["4", 1]
        },
        "class_type": "CLIPTextEncode",
        "_meta": {"title": "CLIP Text Encode (Prompt)"}
    },
    "7": {
        "inputs": {
            "text": "Negative Prompt",
            "clip": ["4", 1]
        },
        "class_type": "CLIPTextEncode",
        "_meta": {"title": "CLIP Text Encode (Prompt)"}
    },
    "8": {
        "inputs": {
            "samples": ["3", 0],
            "vae": ["4", 2]
        },
        "class_type": "VAEDecode",
        "_meta": {"title": "VAE Decode"}
    },
    "9": {
        "inputs": {
            "filename_prefix": "ComfyUI",
            "images": ["8", 0]
        },
        "class_type": "SaveImage",
        "_meta": {"title": "Save Image"}
    }
}
"""

FLUX_DEFAULT_PROMPT = """
{
    "5": {
        "inputs": {
            "width": 1024,
            "height": 1024,
            "batch_size": 1
        },
        "class_type": "EmptyLatentImage"
    },
    "6": {
        "inputs": {
            "text": "Input Text Here",
            "clip": ["11", 0]
        },
        "class_type": "CLIPTextEncode"
    },
    "8": {
        "inputs": {
            "samples": ["13", 0],
            "vae": ["10", 0]
        },
        "class_type": "VAEDecode"
    },
    "9": {
        "inputs": {
            "filename_prefix": "ComfyUI",
            "images": ["8", 0]
        },
        "class_type": "SaveImage"
    },
    "10": {
        "inputs": {
            "vae_name": "ae.sft"
        },
        "class_type": "VAELoader"
    },
    "11": {
        "inputs": {
            "clip_name1": "clip_l.safetensors",
            "clip_name2": "t5xxl_fp16.safetensors",
            "type": "flux"
        },
        "class_type": "DualCLIPLoader"
    },
    "12": {
        "inputs": {
            "unet_name": "flux1-dev.sft",
            "weight_dtype": "default"
        },
        "class_type": "UNETLoader"
    },
    "13": {
        "inputs": {
            "noise": ["25", 0],
            "guider": ["22", 0],
            "sampler": ["16", 0],
            "sigmas": ["17", 0],
            "latent_image": ["5", 0]
        },
        "class_type": "SamplerCustomAdvanced"
    },
    "16": {
        "inputs": {
            "sampler_name": "euler"
        },
        "class_type": "KSamplerSelect"
    },
    "17": {
        "inputs": {
            "scheduler": "simple",
            "steps": 20,
            "denoise": 1,
            "model": ["12", 0]
        },
        "class_type": "BasicScheduler"
    },
    "22": {
        "inputs": {
            "model": ["12", 0],
            "conditioning": ["6", 0]
        },
        "class_type": "BasicGuider"
    },
    "25": {
        "inputs": {
            "noise_seed": 778937779713005
        },
        "class_type": "RandomNoise"
    }
}
"""


def queue_prompt(prompt, client_id, base_url):
    """Submit a workflow (prompt graph) to the ComfyUI /prompt endpoint and return the queue response."""
    log.info("queue_prompt")
    p = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(p).encode("utf-8")
    req = urllib.request.Request(f"{base_url}/prompt", data=data)
    return json.loads(urllib.request.urlopen(req).read())


def get_image(filename, subfolder, folder_type, base_url):
    """Fetch the raw bytes of a generated image via the ComfyUI /view endpoint."""
    log.info("get_image")
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
        return response.read()


def get_image_url(filename, subfolder, folder_type, base_url):
    """Build the ComfyUI /view URL for a generated image without downloading it."""
    log.info("get_image_url")
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    return f"{base_url}/view?{url_values}"


def get_history(prompt_id, base_url):
    """Return the execution history (including output metadata) for a queued prompt."""
    log.info("get_history")
    with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
        return json.loads(response.read())


def get_images(ws, prompt, client_id, base_url):
    """Queue a prompt, wait on the WebSocket until execution finishes, then collect the output image URLs."""
    prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
    output_images = []
    while True:
        out = ws.recv()
        if isinstance(out, str):
            message = json.loads(out)
            if message["type"] == "executing":
                data = message["data"]
                if data["node"] is None and data["prompt_id"] == prompt_id:
                    break  # Execution is done
        else:
            continue  # previews are binary data

    history = get_history(prompt_id, base_url)[prompt_id]
    for node_id in history["outputs"]:
        node_output = history["outputs"][node_id]
        if "images" in node_output:
            for image in node_output["images"]:
                url = get_image_url(
                    image["filename"], image["subfolder"], image["type"], base_url
                )
                output_images.append({"url": url})
    return {"data": output_images}


class ImageGenerationPayload(BaseModel):
    """Request parameters for a ComfyUI image generation job."""

    prompt: str
    negative_prompt: Optional[str] = ""
    steps: Optional[int] = None
    seed: Optional[int] = None
    width: int
    height: int
    n: int = 1
    cfg_scale: Optional[float] = None
    sampler: Optional[str] = None
    scheduler: Optional[str] = None
    sd3: Optional[bool] = None
    flux: Optional[bool] = None
    flux_weight_dtype: Optional[str] = None
    flux_fp8_clip: Optional[bool] = None


async def comfyui_generate_image(
    model: str, payload: ImageGenerationPayload, client_id, base_url
):
    """Build a workflow from the payload, queue it on the ComfyUI server, and return the output image URLs."""
    ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")

    comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)

    if payload.cfg_scale:
        comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale

    if payload.sampler:
        comfyui_prompt["3"]["inputs"]["sampler_name"] = payload.sampler

    if payload.scheduler:
        comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler

    if payload.sd3:
        comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"

    if payload.steps:
        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps

    comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
    comfyui_prompt["3"]["inputs"]["seed"] = (
        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
    )

    # as Flux uses a completely different workflow, we must treat it specially
    if payload.flux:
        comfyui_prompt = json.loads(FLUX_DEFAULT_PROMPT)
        comfyui_prompt["12"]["inputs"]["unet_name"] = model
        comfyui_prompt["25"]["inputs"]["noise_seed"] = (
            payload.seed if payload.seed else random.randint(0, 18446744073709551614)
        )

        if payload.sampler:
            comfyui_prompt["16"]["inputs"]["sampler_name"] = payload.sampler

        if payload.steps:
            comfyui_prompt["17"]["inputs"]["steps"] = payload.steps

        if payload.scheduler:
            comfyui_prompt["17"]["inputs"]["scheduler"] = payload.scheduler

        if payload.flux_weight_dtype:
            comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype

        if payload.flux_fp8_clip:
            comfyui_prompt["11"]["inputs"]["clip_name2"] = "t5xxl_fp8_e4m3fn.safetensors"

    # node "5" (latent size/batch) and node "6" (positive prompt) exist in both workflows
    comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
    comfyui_prompt["5"]["inputs"]["width"] = payload.width
    comfyui_prompt["5"]["inputs"]["height"] = payload.height

    # set the text prompt for our positive CLIPTextEncode
    comfyui_prompt["6"]["inputs"]["text"] = payload.prompt

    try:
        ws = websocket.WebSocket()
        ws.connect(f"{ws_url}/ws?clientId={client_id}")
        log.info("WebSocket connection established.")
    except Exception as e:
        log.exception(f"Failed to connect to WebSocket server: {e}")
        return None

    try:
        # ws.recv() blocks, so run the polling loop in a worker thread
        images = await asyncio.to_thread(
            get_images, ws, comfyui_prompt, client_id, base_url
        )
    except Exception as e:
        log.exception(f"Error while receiving images: {e}")
        images = None

    ws.close()
    return images
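

# Example usage (illustrative sketch, not part of the original module): assumes a local
# ComfyUI server at http://127.0.0.1:8188 and a checkpoint file named "model.safetensors".
if __name__ == "__main__":
    example_payload = ImageGenerationPayload(
        prompt="a watercolor painting of a lighthouse at dawn",
        negative_prompt="blurry, low quality",
        width=512,
        height=512,
        steps=20,
        n=1,
    )
    result = asyncio.run(
        comfyui_generate_image(
            "model.safetensors",
            example_payload,
            client_id=str(uuid.uuid4()),
            base_url="http://127.0.0.1:8188",
        )
    )
    print(result)  # {"data": [{"url": "..."}]} on success, None on connection failure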