comfyui.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
  2. import uuid
  3. import json
  4. import urllib.request
  5. import urllib.parse
  6. import random
  7. from pydantic import BaseModel
  8. from typing import Optional
  9. COMFYUI_DEFAULT_PROMPT = """
  10. {
  11. "3": {
  12. "inputs": {
  13. "seed": 0,
  14. "steps": 20,
  15. "cfg": 8,
  16. "sampler_name": "euler",
  17. "scheduler": "normal",
  18. "denoise": 1,
  19. "model": [
  20. "4",
  21. 0
  22. ],
  23. "positive": [
  24. "6",
  25. 0
  26. ],
  27. "negative": [
  28. "7",
  29. 0
  30. ],
  31. "latent_image": [
  32. "5",
  33. 0
  34. ]
  35. },
  36. "class_type": "KSampler",
  37. "_meta": {
  38. "title": "KSampler"
  39. }
  40. },
  41. "4": {
  42. "inputs": {
  43. "ckpt_name": "model.safetensors"
  44. },
  45. "class_type": "CheckpointLoaderSimple",
  46. "_meta": {
  47. "title": "Load Checkpoint"
  48. }
  49. },
  50. "5": {
  51. "inputs": {
  52. "width": 512,
  53. "height": 512,
  54. "batch_size": 1
  55. },
  56. "class_type": "EmptyLatentImage",
  57. "_meta": {
  58. "title": "Empty Latent Image"
  59. }
  60. },
  61. "6": {
  62. "inputs": {
  63. "text": "Prompt",
  64. "clip": [
  65. "4",
  66. 1
  67. ]
  68. },
  69. "class_type": "CLIPTextEncode",
  70. "_meta": {
  71. "title": "CLIP Text Encode (Prompt)"
  72. }
  73. },
  74. "7": {
  75. "inputs": {
  76. "text": "Negative Prompt",
  77. "clip": [
  78. "4",
  79. 1
  80. ]
  81. },
  82. "class_type": "CLIPTextEncode",
  83. "_meta": {
  84. "title": "CLIP Text Encode (Prompt)"
  85. }
  86. },
  87. "8": {
  88. "inputs": {
  89. "samples": [
  90. "3",
  91. 0
  92. ],
  93. "vae": [
  94. "4",
  95. 2
  96. ]
  97. },
  98. "class_type": "VAEDecode",
  99. "_meta": {
  100. "title": "VAE Decode"
  101. }
  102. },
  103. "9": {
  104. "inputs": {
  105. "filename_prefix": "ComfyUI",
  106. "images": [
  107. "8",
  108. 0
  109. ]
  110. },
  111. "class_type": "SaveImage",
  112. "_meta": {
  113. "title": "Save Image"
  114. }
  115. }
  116. }
  117. """
  118. def queue_prompt(prompt, client_id, base_url):
  119. print("queue_prompt")
  120. p = {"prompt": prompt, "client_id": client_id}
  121. data = json.dumps(p).encode("utf-8")
  122. req = urllib.request.Request(f"{base_url}/prompt", data=data)
  123. return json.loads(urllib.request.urlopen(req).read())
  124. def get_image(filename, subfolder, folder_type, base_url):
  125. print("get_image")
  126. data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
  127. url_values = urllib.parse.urlencode(data)
  128. with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
  129. return response.read()
  130. def get_image_url(filename, subfolder, folder_type, base_url):
  131. print("get_image")
  132. data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
  133. url_values = urllib.parse.urlencode(data)
  134. return f"{base_url}/view?{url_values}"
  135. def get_history(prompt_id, base_url):
  136. print("get_history")
  137. with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
  138. return json.loads(response.read())
  139. def get_images(ws, prompt, client_id, base_url):
  140. prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
  141. output_images = []
  142. while True:
  143. out = ws.recv()
  144. if isinstance(out, str):
  145. message = json.loads(out)
  146. if message["type"] == "executing":
  147. data = message["data"]
  148. if data["node"] is None and data["prompt_id"] == prompt_id:
  149. break # Execution is done
  150. else:
  151. continue # previews are binary data
  152. history = get_history(prompt_id, base_url)[prompt_id]
  153. for o in history["outputs"]:
  154. for node_id in history["outputs"]:
  155. node_output = history["outputs"][node_id]
  156. if "images" in node_output:
  157. for image in node_output["images"]:
  158. url = get_image_url(
  159. image["filename"], image["subfolder"], image["type"], base_url
  160. )
  161. output_images.append({"url": url})
  162. return {"data": output_images}
  163. class ImageGenerationPayload(BaseModel):
  164. prompt: str
  165. negative_prompt: Optional[str] = ""
  166. steps: Optional[int] = None
  167. seed: Optional[int] = None
  168. width: int
  169. height: int
  170. n: int = 1
  171. def comfyui_generate_image(
  172. model: str, payload: ImageGenerationPayload, client_id, base_url
  173. ):
  174. host = base_url.replace("http://", "").replace("https://", "")
  175. comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
  176. comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
  177. comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
  178. comfyui_prompt["5"]["inputs"]["width"] = payload.width
  179. comfyui_prompt["5"]["inputs"]["height"] = payload.height
  180. # set the text prompt for our positive CLIPTextEncode
  181. comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
  182. comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
  183. if payload.steps:
  184. comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
  185. comfyui_prompt["3"]["inputs"]["seed"] = (
  186. payload.seed if payload.seed else random.randint(0, 18446744073709551614)
  187. )
  188. try:
  189. ws = websocket.WebSocket()
  190. ws.connect(f"ws://{host}/ws?clientId={client_id}")
  191. print("WebSocket connection established.")
  192. except Exception as e:
  193. print(f"Failed to connect to WebSocket server: {e}")
  194. return None
  195. try:
  196. images = get_images(ws, comfyui_prompt, client_id, base_url)
  197. except Exception as e:
  198. print(f"Error while receiving images: {e}")
  199. images = None
  200. ws.close()
  201. return images