comfyui.py

import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import random
import logging

from pydantic import BaseModel
from typing import Optional

from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])

COMFYUI_DEFAULT_PROMPT = """
{
  "3": {
    "inputs": {
      "seed": 0,
      "steps": 20,
      "cfg": 8,
      "sampler_name": "euler",
      "scheduler": "normal",
      "denoise": 1,
      "model": ["4", 0],
      "positive": ["6", 0],
      "negative": ["7", 0],
      "latent_image": ["5", 0]
    },
    "class_type": "KSampler",
    "_meta": {
      "title": "KSampler"
    }
  },
  "4": {
    "inputs": {
      "ckpt_name": "model.safetensors"
    },
    "class_type": "CheckpointLoaderSimple",
    "_meta": {
      "title": "Load Checkpoint"
    }
  },
  "5": {
    "inputs": {
      "width": 512,
      "height": 512,
      "batch_size": 1
    },
    "class_type": "EmptyLatentImage",
    "_meta": {
      "title": "Empty Latent Image"
    }
  },
  "6": {
    "inputs": {
      "text": "Prompt",
      "clip": ["4", 1]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {
      "title": "CLIP Text Encode (Prompt)"
    }
  },
  "7": {
    "inputs": {
      "text": "Negative Prompt",
      "clip": ["4", 1]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {
      "title": "CLIP Text Encode (Prompt)"
    }
  },
  "8": {
    "inputs": {
      "samples": ["3", 0],
      "vae": ["4", 2]
    },
    "class_type": "VAEDecode",
    "_meta": {
      "title": "VAE Decode"
    }
  },
  "9": {
    "inputs": {
      "filename_prefix": "ComfyUI",
      "images": ["8", 0]
    },
    "class_type": "SaveImage",
    "_meta": {
      "title": "Save Image"
    }
  }
}
"""

FLUX_DEFAULT_PROMPT = """
{
  "5": {
    "inputs": {
      "width": 1024,
      "height": 1024,
      "batch_size": 1
    },
    "class_type": "EmptyLatentImage"
  },
  "6": {
    "inputs": {
      "text": "Input Text Here",
      "clip": ["11", 0]
    },
    "class_type": "CLIPTextEncode"
  },
  "8": {
    "inputs": {
      "samples": ["13", 0],
      "vae": ["10", 0]
    },
    "class_type": "VAEDecode"
  },
  "9": {
    "inputs": {
      "filename_prefix": "ComfyUI",
      "images": ["8", 0]
    },
    "class_type": "SaveImage"
  },
  "10": {
    "inputs": {
      "vae_name": "ae.sft"
    },
    "class_type": "VAELoader"
  },
  "11": {
    "inputs": {
      "clip_name1": "clip_l.safetensors",
      "clip_name2": "t5xxl_fp16.safetensors",
      "type": "flux"
    },
    "class_type": "DualCLIPLoader"
  },
  "12": {
    "inputs": {
      "unet_name": "flux1-dev.sft",
      "weight_dtype": "default"
    },
    "class_type": "UNETLoader"
  },
  "13": {
    "inputs": {
      "noise": ["25", 0],
      "guider": ["22", 0],
      "sampler": ["16", 0],
      "sigmas": ["17", 0],
      "latent_image": ["5", 0]
    },
    "class_type": "SamplerCustomAdvanced"
  },
  "16": {
    "inputs": {
      "sampler_name": "euler"
    },
    "class_type": "KSamplerSelect"
  },
  "17": {
    "inputs": {
      "scheduler": "simple",
      "steps": 20,
      "denoise": 1,
      "model": ["12", 0]
    },
    "class_type": "BasicScheduler"
  },
  "22": {
    "inputs": {
      "model": ["12", 0],
      "conditioning": ["6", 0]
    },
    "class_type": "BasicGuider"
  },
  "25": {
    "inputs": {
      "noise_seed": 778937779713005
    },
    "class_type": "RandomNoise"
  }
}
"""


def queue_prompt(prompt, client_id, base_url):
    """Submit a workflow to ComfyUI's /prompt endpoint and return its JSON response."""
    log.info("queue_prompt")
    p = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(p).encode("utf-8")
    req = urllib.request.Request(f"{base_url}/prompt", data=data)
    return json.loads(urllib.request.urlopen(req).read())


def get_image(filename, subfolder, folder_type, base_url):
    """Download an image's raw bytes from ComfyUI's /view endpoint."""
    log.info("get_image")
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
        return response.read()


def get_image_url(filename, subfolder, folder_type, base_url):
    """Build the /view URL for an image without downloading it."""
    log.info("get_image_url")
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    return f"{base_url}/view?{url_values}"


def get_history(prompt_id, base_url):
    """Fetch the execution history (including output metadata) for a prompt ID."""
    log.info("get_history")
    with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
        return json.loads(response.read())


def get_images(ws, prompt, client_id, base_url):
    """Queue a workflow, wait for it to finish over the WebSocket, and collect image URLs."""
    prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
    output_images = []
    while True:
        out = ws.recv()
        if isinstance(out, str):
            message = json.loads(out)
            if message["type"] == "executing":
                data = message["data"]
                if data["node"] is None and data["prompt_id"] == prompt_id:
                    break  # Execution is done
        else:
            continue  # previews are binary data
    history = get_history(prompt_id, base_url)[prompt_id]
    for node_id in history["outputs"]:
        node_output = history["outputs"][node_id]
        if "images" in node_output:
            for image in node_output["images"]:
                url = get_image_url(
                    image["filename"], image["subfolder"], image["type"], base_url
                )
                output_images.append({"url": url})
    return {"data": output_images}


class ImageGenerationPayload(BaseModel):
    prompt: str
    negative_prompt: Optional[str] = ""
    steps: Optional[int] = None
    seed: Optional[int] = None
    width: int
    height: int
    n: int = 1
    cfg_scale: Optional[float] = None
    sampler: Optional[str] = None
    scheduler: Optional[str] = None
    sd3: Optional[bool] = None
    flux: Optional[bool] = None
    flux_weight_dtype: Optional[str] = None
    flux_fp8_clip: Optional[bool] = None


def comfyui_generate_image(
    model: str, payload: ImageGenerationPayload, client_id, base_url
):
    ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
    comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)

    if payload.cfg_scale:
        comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale

    if payload.sampler:
        comfyui_prompt["3"]["inputs"]["sampler_name"] = payload.sampler

    if payload.scheduler:
        comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler

    if payload.sd3:
        comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"

    if payload.steps:
        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps

    comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
    comfyui_prompt["3"]["inputs"]["seed"] = (
        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
    )

    # As Flux uses a completely different workflow, we must treat it specially.
    if payload.flux:
        comfyui_prompt = json.loads(FLUX_DEFAULT_PROMPT)
        comfyui_prompt["12"]["inputs"]["unet_name"] = model
        comfyui_prompt["25"]["inputs"]["noise_seed"] = (
            payload.seed if payload.seed else random.randint(0, 18446744073709551614)
        )

        if payload.sampler:
            comfyui_prompt["16"]["inputs"]["sampler_name"] = payload.sampler

        if payload.steps:
            comfyui_prompt["17"]["inputs"]["steps"] = payload.steps

        if payload.scheduler:
            comfyui_prompt["17"]["inputs"]["scheduler"] = payload.scheduler

        if payload.flux_weight_dtype:
            comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype

        if payload.flux_fp8_clip:
            comfyui_prompt["11"]["inputs"]["clip_name2"] = "t5xxl_fp8_e4m3fn.safetensors"

    # Nodes "5" (latent image) and "6" (positive prompt) exist in both workflows.
    comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
    comfyui_prompt["5"]["inputs"]["width"] = payload.width
    comfyui_prompt["5"]["inputs"]["height"] = payload.height

    # Set the text prompt for our positive CLIPTextEncode node.
    comfyui_prompt["6"]["inputs"]["text"] = payload.prompt

    try:
        ws = websocket.WebSocket()
        ws.connect(f"{ws_url}/ws?clientId={client_id}")
        log.info("WebSocket connection established.")
    except Exception as e:
        log.exception(f"Failed to connect to WebSocket server: {e}")
        return None

    try:
        images = get_images(ws, comfyui_prompt, client_id, base_url)
    except Exception as e:
        log.exception(f"Error while receiving images: {e}")
        images = None

    ws.close()
    return images
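

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not called anywhere else in this module):
# shows how the pieces above fit together. The base URL, checkpoint name, and
# payload values are assumptions -- adjust them to your ComfyUI installation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_payload = ImageGenerationPayload(
        prompt="a watercolor painting of a lighthouse at dusk",
        negative_prompt="blurry, low quality",
        steps=20,
        width=512,
        height=512,
        n=1,
    )
    result = comfyui_generate_image(
        model="model.safetensors",  # assumed to exist in ComfyUI's models/checkpoints
        payload=example_payload,
        client_id=str(uuid.uuid4()),
        base_url="http://127.0.0.1:8188",  # ComfyUI's default local address
    )
    log.info(result)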