comfyui.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import asyncio
  2. import json
  3. import logging
  4. import random
  5. import urllib.parse
  6. import urllib.request
  7. from typing import Optional
  8. import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
  9. from open_webui.env import SRC_LOG_LEVELS
  10. from pydantic import BaseModel
  11. log = logging.getLogger(__name__)
  12. log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
  13. default_headers = {"User-Agent": "Mozilla/5.0"}
  14. def queue_prompt(prompt, client_id, base_url, api_key):
  15. log.info("queue_prompt")
  16. p = {"prompt": prompt, "client_id": client_id}
  17. data = json.dumps(p).encode("utf-8")
  18. log.debug(f"queue_prompt data: {data}")
  19. try:
  20. req = urllib.request.Request(
  21. f"{base_url}/prompt",
  22. data=data,
  23. headers={**default_headers, "Authorization": f"Bearer {api_key}"},
  24. )
  25. response = urllib.request.urlopen(req).read()
  26. return json.loads(response)
  27. except Exception as e:
  28. log.exception(f"Error while queuing prompt: {e}")
  29. raise e
  30. def get_image(filename, subfolder, folder_type, base_url, api_key):
  31. log.info("get_image")
  32. data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
  33. url_values = urllib.parse.urlencode(data)
  34. req = urllib.request.Request(
  35. f"{base_url}/view?{url_values}",
  36. headers={**default_headers, "Authorization": f"Bearer {api_key}"},
  37. )
  38. with urllib.request.urlopen(req) as response:
  39. return response.read()
  40. def get_image_url(filename, subfolder, folder_type, base_url):
  41. log.info("get_image")
  42. data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
  43. url_values = urllib.parse.urlencode(data)
  44. return f"{base_url}/view?{url_values}"
  45. def get_history(prompt_id, base_url, api_key):
  46. log.info("get_history")
  47. req = urllib.request.Request(
  48. f"{base_url}/history/{prompt_id}",
  49. headers={**default_headers, "Authorization": f"Bearer {api_key}"},
  50. )
  51. with urllib.request.urlopen(req) as response:
  52. return json.loads(response.read())
  53. def get_images(ws, prompt, client_id, base_url, api_key):
  54. prompt_id = queue_prompt(prompt, client_id, base_url, api_key)["prompt_id"]
  55. output_images = []
  56. while True:
  57. out = ws.recv()
  58. if isinstance(out, str):
  59. message = json.loads(out)
  60. if message["type"] == "executing":
  61. data = message["data"]
  62. if data["node"] is None and data["prompt_id"] == prompt_id:
  63. break # Execution is done
  64. else:
  65. continue # previews are binary data
  66. history = get_history(prompt_id, base_url, api_key)[prompt_id]
  67. for o in history["outputs"]:
  68. for node_id in history["outputs"]:
  69. node_output = history["outputs"][node_id]
  70. if "images" in node_output:
  71. for image in node_output["images"]:
  72. url = get_image_url(
  73. image["filename"], image["subfolder"], image["type"], base_url
  74. )
  75. output_images.append({"url": url})
  76. return {"data": output_images}
class ComfyUINodeInput(BaseModel):
    """One rule for injecting a request value into inputs of workflow nodes.

    For each id in ``node_ids``, the value selected by ``type`` is written to
    ``workflow[node_id]["inputs"][key]`` of the parsed workflow JSON.
    """

    # Which value to inject: "model", "prompt", "negative_prompt", "width",
    # "height", "n", "steps", "seed", or any other string (then ``value`` is
    # written). Rules with an empty/None type are skipped.
    type: Optional[str] = None
    # Ids of the workflow nodes (top-level keys of the workflow JSON) to update.
    node_ids: list[str] = []
    # Name of the node input to set; some types fall back to a per-type
    # default input name when this is empty.
    key: Optional[str] = "text"
    # Literal value written for types without a dedicated handler.
    value: Optional[str] = None
class ComfyUIWorkflow(BaseModel):
    """A ComfyUI workflow plus the rules mapping request values onto its nodes."""

    # Workflow graph as a JSON string; it is parsed with json.loads before use.
    workflow: str
    # Injection rules applied to the parsed workflow.
    nodes: list[ComfyUINodeInput]
class ComfyUIGenerateImageForm(BaseModel):
    """Request parameters for one ComfyUI image-generation run."""

    # Workflow to run and how the fields below map onto its nodes.
    workflow: ComfyUIWorkflow
    # Written into nodes mapped with type "prompt".
    prompt: str
    # Written into nodes mapped with type "negative_prompt".
    negative_prompt: Optional[str] = None
    width: int
    height: int
    # Number of images; written to the "batch_size" input unless the rule names a key.
    n: int = 1
    steps: Optional[int] = None
    # Explicit seed; when None, a random seed is generated for seed nodes.
    seed: Optional[int] = None
  94. async def comfyui_generate_image(
  95. model: str, payload: ComfyUIGenerateImageForm, client_id, base_url, api_key
  96. ):
  97. ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
  98. workflow = json.loads(payload.workflow.workflow)
  99. for node in payload.workflow.nodes:
  100. if node.type:
  101. if node.type == "model":
  102. for node_id in node.node_ids:
  103. workflow[node_id]["inputs"][node.key] = model
  104. elif node.type == "prompt":
  105. for node_id in node.node_ids:
  106. workflow[node_id]["inputs"][
  107. node.key if node.key else "text"
  108. ] = payload.prompt
  109. elif node.type == "negative_prompt":
  110. for node_id in node.node_ids:
  111. workflow[node_id]["inputs"][
  112. node.key if node.key else "text"
  113. ] = payload.negative_prompt
  114. elif node.type == "width":
  115. for node_id in node.node_ids:
  116. workflow[node_id]["inputs"][
  117. node.key if node.key else "width"
  118. ] = payload.width
  119. elif node.type == "height":
  120. for node_id in node.node_ids:
  121. workflow[node_id]["inputs"][
  122. node.key if node.key else "height"
  123. ] = payload.height
  124. elif node.type == "n":
  125. for node_id in node.node_ids:
  126. workflow[node_id]["inputs"][
  127. node.key if node.key else "batch_size"
  128. ] = payload.n
  129. elif node.type == "steps":
  130. for node_id in node.node_ids:
  131. workflow[node_id]["inputs"][
  132. node.key if node.key else "steps"
  133. ] = payload.steps
  134. elif node.type == "seed":
  135. seed = (
  136. payload.seed
  137. if payload.seed
  138. else random.randint(0, 18446744073709551614)
  139. )
  140. for node_id in node.node_ids:
  141. workflow[node_id]["inputs"][node.key] = seed
  142. else:
  143. for node_id in node.node_ids:
  144. workflow[node_id]["inputs"][node.key] = node.value
  145. try:
  146. ws = websocket.WebSocket()
  147. headers = {"Authorization": f"Bearer {api_key}"}
  148. ws.connect(f"{ws_url}/ws?clientId={client_id}", header=headers)
  149. log.info("WebSocket connection established.")
  150. except Exception as e:
  151. log.exception(f"Failed to connect to WebSocket server: {e}")
  152. return None
  153. try:
  154. log.info("Sending workflow to WebSocket server.")
  155. log.info(f"Workflow: {workflow}")
  156. images = await asyncio.to_thread(
  157. get_images, ws, workflow, client_id, base_url, api_key
  158. )
  159. except Exception as e:
  160. log.exception(f"Error while receiving images: {e}")
  161. images = None
  162. ws.close()
  163. return images