images.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662
  1. import asyncio
  2. import base64
  3. import io
  4. import json
  5. import logging
  6. import mimetypes
  7. import re
  8. from pathlib import Path
  9. from typing import Optional
  10. import requests
  11. from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
  12. from open_webui.config import CACHE_DIR
  13. from open_webui.constants import ERROR_MESSAGES
  14. from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
  15. from open_webui.routers.files import upload_file
  16. from open_webui.utils.auth import get_admin_user, get_verified_user
  17. from open_webui.utils.images.comfyui import (
  18. ComfyUIGenerateImageForm,
  19. ComfyUIWorkflow,
  20. comfyui_generate_image,
  21. )
  22. from pydantic import BaseModel
  23. log = logging.getLogger(__name__)
  24. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  25. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  26. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  27. router = APIRouter()
  28. @router.get("/config")
  29. async def get_config(request: Request, user=Depends(get_admin_user)):
  30. return {
  31. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  32. "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
  33. "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
  34. "openai": {
  35. "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  36. "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  37. },
  38. "automatic1111": {
  39. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  40. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  41. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  42. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  43. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  44. },
  45. "comfyui": {
  46. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  47. "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
  48. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  49. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  50. },
  51. "gemini": {
  52. "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
  53. "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
  54. },
  55. }
# "openai" section of the POST /config/update payload.
class OpenAIConfigForm(BaseModel):
    # Copied to IMAGES_OPENAI_API_BASE_URL by update_config.
    OPENAI_API_BASE_URL: str
    # Copied to IMAGES_OPENAI_API_KEY by update_config.
    OPENAI_API_KEY: str
  59. class Automatic1111ConfigForm(BaseModel):
  60. AUTOMATIC1111_BASE_URL: str
  61. AUTOMATIC1111_API_AUTH: str
  62. AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
  63. AUTOMATIC1111_SAMPLER: Optional[str]
  64. AUTOMATIC1111_SCHEDULER: Optional[str]
# "comfyui" section of the POST /config/update payload.
class ComfyUIConfigForm(BaseModel):
    # update_config strips leading/trailing "/" characters before storing it.
    COMFYUI_BASE_URL: str
    COMFYUI_API_KEY: str
    # Workflow as text; get_models parses it with json.loads.
    COMFYUI_WORKFLOW: str
    # Node descriptors; get_models reads each entry's "type" and "node_ids".
    COMFYUI_WORKFLOW_NODES: list[dict]
# "gemini" section of the POST /config/update payload.
class GeminiConfigForm(BaseModel):
    # Copied to IMAGES_GEMINI_API_BASE_URL by update_config.
    GEMINI_API_BASE_URL: str
    # Copied to IMAGES_GEMINI_API_KEY by update_config.
    GEMINI_API_KEY: str
# Request body of POST /config/update: global switches plus one section per engine.
class ConfigForm(BaseModel):
    enabled: bool  # stored as ENABLE_IMAGE_GENERATION
    engine: str  # stored as IMAGE_GENERATION_ENGINE
    prompt_generation: bool  # stored as ENABLE_IMAGE_PROMPT_GENERATION
    openai: OpenAIConfigForm
    automatic1111: Automatic1111ConfigForm
    comfyui: ComfyUIConfigForm
    gemini: GeminiConfigForm
  81. @router.post("/config/update")
  82. async def update_config(
  83. request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
  84. ):
  85. request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
  86. request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
  87. request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
  88. form_data.prompt_generation
  89. )
  90. request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
  91. form_data.openai.OPENAI_API_BASE_URL
  92. )
  93. request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
  94. request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
  95. form_data.gemini.GEMINI_API_BASE_URL
  96. )
  97. request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
  98. request.app.state.config.AUTOMATIC1111_BASE_URL = (
  99. form_data.automatic1111.AUTOMATIC1111_BASE_URL
  100. )
  101. request.app.state.config.AUTOMATIC1111_API_AUTH = (
  102. form_data.automatic1111.AUTOMATIC1111_API_AUTH
  103. )
  104. request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
  105. float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
  106. if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
  107. else None
  108. )
  109. request.app.state.config.AUTOMATIC1111_SAMPLER = (
  110. form_data.automatic1111.AUTOMATIC1111_SAMPLER
  111. if form_data.automatic1111.AUTOMATIC1111_SAMPLER
  112. else None
  113. )
  114. request.app.state.config.AUTOMATIC1111_SCHEDULER = (
  115. form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  116. if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  117. else None
  118. )
  119. request.app.state.config.COMFYUI_BASE_URL = (
  120. form_data.comfyui.COMFYUI_BASE_URL.strip("/")
  121. )
  122. request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
  123. request.app.state.config.COMFYUI_WORKFLOW_NODES = (
  124. form_data.comfyui.COMFYUI_WORKFLOW_NODES
  125. )
  126. return {
  127. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  128. "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
  129. "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
  130. "openai": {
  131. "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  132. "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  133. },
  134. "automatic1111": {
  135. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  136. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  137. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  138. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  139. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  140. },
  141. "comfyui": {
  142. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  143. "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
  144. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  145. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  146. },
  147. "gemini": {
  148. "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
  149. "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
  150. },
  151. }
  152. def get_automatic1111_api_auth(request: Request):
  153. if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
  154. return ""
  155. else:
  156. auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
  157. "utf-8"
  158. )
  159. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  160. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  161. return f"Basic {auth1111_base64_encoded_string}"
  162. @router.get("/config/url/verify")
  163. async def verify_url(request: Request, user=Depends(get_admin_user)):
  164. if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
  165. try:
  166. r = requests.get(
  167. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  168. headers={"authorization": get_automatic1111_api_auth(request)},
  169. )
  170. r.raise_for_status()
  171. return True
  172. except Exception:
  173. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  174. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  175. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  176. try:
  177. r = requests.get(
  178. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
  179. )
  180. r.raise_for_status()
  181. return True
  182. except Exception:
  183. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  184. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  185. else:
  186. return True
  187. def set_image_model(request: Request, model: str):
  188. log.info(f"Setting image model to {model}")
  189. request.app.state.config.IMAGE_GENERATION_MODEL = model
  190. if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
  191. api_auth = get_automatic1111_api_auth(request)
  192. r = requests.get(
  193. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  194. headers={"authorization": api_auth},
  195. )
  196. options = r.json()
  197. if model != options["sd_model_checkpoint"]:
  198. options["sd_model_checkpoint"] = model
  199. r = requests.post(
  200. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  201. json=options,
  202. headers={"authorization": api_auth},
  203. )
  204. return request.app.state.config.IMAGE_GENERATION_MODEL
  205. def get_image_model(request):
  206. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  207. return (
  208. request.app.state.config.IMAGE_GENERATION_MODEL
  209. if request.app.state.config.IMAGE_GENERATION_MODEL
  210. else "dall-e-2"
  211. )
  212. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  213. return (
  214. request.app.state.config.IMAGE_GENERATION_MODEL
  215. if request.app.state.config.IMAGE_GENERATION_MODEL
  216. else "imagen-3.0-generate-002"
  217. )
  218. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  219. return (
  220. request.app.state.config.IMAGE_GENERATION_MODEL
  221. if request.app.state.config.IMAGE_GENERATION_MODEL
  222. else ""
  223. )
  224. elif (
  225. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  226. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  227. ):
  228. try:
  229. r = requests.get(
  230. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  231. headers={"authorization": get_automatic1111_api_auth(request)},
  232. )
  233. options = r.json()
  234. return options["sd_model_checkpoint"]
  235. except Exception as e:
  236. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  237. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
# Request body of POST /image/config/update.
class ImageConfigForm(BaseModel):
    # Passed to set_image_model by update_image_config.
    MODEL: str
    # Must match r"^\d+x\d+$" (e.g. "512x512"); checked by update_image_config.
    IMAGE_SIZE: str
    # Must be >= 0; checked by update_image_config.
    IMAGE_STEPS: int
  242. @router.get("/image/config")
  243. async def get_image_config(request: Request, user=Depends(get_admin_user)):
  244. return {
  245. "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
  246. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  247. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  248. }
  249. @router.post("/image/config/update")
  250. async def update_image_config(
  251. request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
  252. ):
  253. set_image_model(request, form_data.MODEL)
  254. pattern = r"^\d+x\d+$"
  255. if re.match(pattern, form_data.IMAGE_SIZE):
  256. request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
  257. else:
  258. raise HTTPException(
  259. status_code=400,
  260. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  261. )
  262. if form_data.IMAGE_STEPS >= 0:
  263. request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
  264. else:
  265. raise HTTPException(
  266. status_code=400,
  267. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  268. )
  269. return {
  270. "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
  271. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  272. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  273. }
  274. @router.get("/models")
  275. def get_models(request: Request, user=Depends(get_verified_user)):
  276. try:
  277. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  278. return [
  279. {"id": "dall-e-2", "name": "DALL·E 2"},
  280. {"id": "dall-e-3", "name": "DALL·E 3"},
  281. ]
  282. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  283. return [
  284. {"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
  285. ]
  286. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  287. # TODO - get models from comfyui
  288. headers = {
  289. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  290. }
  291. r = requests.get(
  292. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
  293. headers=headers,
  294. )
  295. info = r.json()
  296. workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
  297. model_node_id = None
  298. for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
  299. if node["type"] == "model":
  300. if node["node_ids"]:
  301. model_node_id = node["node_ids"][0]
  302. break
  303. if model_node_id:
  304. model_list_key = None
  305. print(workflow[model_node_id]["class_type"])
  306. for key in info[workflow[model_node_id]["class_type"]]["input"][
  307. "required"
  308. ]:
  309. if "_name" in key:
  310. model_list_key = key
  311. break
  312. if model_list_key:
  313. return list(
  314. map(
  315. lambda model: {"id": model, "name": model},
  316. info[workflow[model_node_id]["class_type"]]["input"][
  317. "required"
  318. ][model_list_key][0],
  319. )
  320. )
  321. else:
  322. return list(
  323. map(
  324. lambda model: {"id": model, "name": model},
  325. info["CheckpointLoaderSimple"]["input"]["required"][
  326. "ckpt_name"
  327. ][0],
  328. )
  329. )
  330. elif (
  331. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  332. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  333. ):
  334. r = requests.get(
  335. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  336. headers={"authorization": get_automatic1111_api_auth(request)},
  337. )
  338. models = r.json()
  339. return list(
  340. map(
  341. lambda model: {"id": model["title"], "name": model["model_name"]},
  342. models,
  343. )
  344. )
  345. except Exception as e:
  346. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  347. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  348. class GenerateImageForm(BaseModel):
  349. model: Optional[str] = None
  350. prompt: str
  351. size: Optional[str] = None
  352. n: int = 1
  353. negative_prompt: Optional[str] = None
  354. def load_b64_image_data(b64_str):
  355. try:
  356. if "," in b64_str:
  357. header, encoded = b64_str.split(",", 1)
  358. mime_type = header.split(";")[0]
  359. img_data = base64.b64decode(encoded)
  360. else:
  361. mime_type = "image/png"
  362. img_data = base64.b64decode(b64_str)
  363. return img_data, mime_type
  364. except Exception as e:
  365. log.exception(f"Error loading image data: {e}")
  366. return None
  367. def load_url_image_data(url, headers=None):
  368. try:
  369. if headers:
  370. r = requests.get(url, headers=headers)
  371. else:
  372. r = requests.get(url)
  373. r.raise_for_status()
  374. if r.headers["content-type"].split("/")[0] == "image":
  375. mime_type = r.headers["content-type"]
  376. return r.content, mime_type
  377. else:
  378. log.error("Url does not point to an image.")
  379. return None
  380. except Exception as e:
  381. log.exception(f"Error saving image: {e}")
  382. return None
  383. def upload_image(request, image_metadata, image_data, content_type, user):
  384. image_format = mimetypes.guess_extension(content_type)
  385. file = UploadFile(
  386. file=io.BytesIO(image_data),
  387. filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
  388. headers={
  389. "content-type": content_type,
  390. },
  391. )
  392. file_item = upload_file(request, file, user, file_metadata=image_metadata)
  393. url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
  394. return url
  395. @router.post("/generations")
  396. async def image_generations(
  397. request: Request,
  398. form_data: GenerateImageForm,
  399. user=Depends(get_verified_user),
  400. ):
  401. width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
  402. r = None
  403. try:
  404. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  405. headers = {}
  406. headers["Authorization"] = (
  407. f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
  408. )
  409. headers["Content-Type"] = "application/json"
  410. if ENABLE_FORWARD_USER_INFO_HEADERS:
  411. headers["X-OpenWebUI-User-Name"] = user.name
  412. headers["X-OpenWebUI-User-Id"] = user.id
  413. headers["X-OpenWebUI-User-Email"] = user.email
  414. headers["X-OpenWebUI-User-Role"] = user.role
  415. data = {
  416. "model": (
  417. request.app.state.config.IMAGE_GENERATION_MODEL
  418. if request.app.state.config.IMAGE_GENERATION_MODEL != ""
  419. else "dall-e-2"
  420. ),
  421. "prompt": form_data.prompt,
  422. "n": form_data.n,
  423. "size": (
  424. form_data.size
  425. if form_data.size
  426. else request.app.state.config.IMAGE_SIZE
  427. ),
  428. "response_format": "b64_json",
  429. }
  430. # Use asyncio.to_thread for the requests.post call
  431. r = await asyncio.to_thread(
  432. requests.post,
  433. url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
  434. json=data,
  435. headers=headers,
  436. )
  437. r.raise_for_status()
  438. res = r.json()
  439. images = []
  440. for image in res["data"]:
  441. image_data, content_type = load_b64_image_data(image["b64_json"])
  442. url = upload_image(request, data, image_data, content_type, user)
  443. images.append({"url": url})
  444. return images
  445. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  446. headers = {}
  447. headers["Content-Type"] = "application/json"
  448. headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
  449. model = get_image_model(request)
  450. data = {
  451. "instances": {"prompt": form_data.prompt},
  452. "parameters": {
  453. "sampleCount": form_data.n,
  454. "outputOptions": {"mimeType": "image/png"},
  455. },
  456. }
  457. # Use asyncio.to_thread for the requests.post call
  458. r = await asyncio.to_thread(
  459. requests.post,
  460. url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
  461. json=data,
  462. headers=headers,
  463. )
  464. r.raise_for_status()
  465. res = r.json()
  466. images = []
  467. for image in res["predictions"]:
  468. image_data, content_type = load_b64_image_data(
  469. image["bytesBase64Encoded"]
  470. )
  471. url = upload_image(request, data, image_data, content_type, user)
  472. images.append({"url": url})
  473. return images
  474. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  475. data = {
  476. "prompt": form_data.prompt,
  477. "width": width,
  478. "height": height,
  479. "n": form_data.n,
  480. }
  481. if request.app.state.config.IMAGE_STEPS is not None:
  482. data["steps"] = request.app.state.config.IMAGE_STEPS
  483. if form_data.negative_prompt is not None:
  484. data["negative_prompt"] = form_data.negative_prompt
  485. form_data = ComfyUIGenerateImageForm(
  486. **{
  487. "workflow": ComfyUIWorkflow(
  488. **{
  489. "workflow": request.app.state.config.COMFYUI_WORKFLOW,
  490. "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  491. }
  492. ),
  493. **data,
  494. }
  495. )
  496. res = await comfyui_generate_image(
  497. request.app.state.config.IMAGE_GENERATION_MODEL,
  498. form_data,
  499. user.id,
  500. request.app.state.config.COMFYUI_BASE_URL,
  501. request.app.state.config.COMFYUI_API_KEY,
  502. )
  503. log.debug(f"res: {res}")
  504. images = []
  505. for image in res["data"]:
  506. headers = None
  507. if request.app.state.config.COMFYUI_API_KEY:
  508. headers = {
  509. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  510. }
  511. image_data, content_type = load_url_image_data(image["url"], headers)
  512. url = upload_image(
  513. request,
  514. form_data.model_dump(exclude_none=True),
  515. image_data,
  516. content_type,
  517. user,
  518. )
  519. images.append({"url": url})
  520. return images
  521. elif (
  522. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  523. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  524. ):
  525. if form_data.model:
  526. set_image_model(form_data.model)
  527. data = {
  528. "prompt": form_data.prompt,
  529. "batch_size": form_data.n,
  530. "width": width,
  531. "height": height,
  532. }
  533. if request.app.state.config.IMAGE_STEPS is not None:
  534. data["steps"] = request.app.state.config.IMAGE_STEPS
  535. if form_data.negative_prompt is not None:
  536. data["negative_prompt"] = form_data.negative_prompt
  537. if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
  538. data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
  539. if request.app.state.config.AUTOMATIC1111_SAMPLER:
  540. data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
  541. if request.app.state.config.AUTOMATIC1111_SCHEDULER:
  542. data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
  543. # Use asyncio.to_thread for the requests.post call
  544. r = await asyncio.to_thread(
  545. requests.post,
  546. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  547. json=data,
  548. headers={"authorization": get_automatic1111_api_auth(request)},
  549. )
  550. res = r.json()
  551. log.debug(f"res: {res}")
  552. images = []
  553. for image in res["images"]:
  554. image_data, content_type = load_b64_image_data(image)
  555. url = upload_image(
  556. request,
  557. {**data, "info": res["info"]},
  558. image_data,
  559. content_type,
  560. user,
  561. )
  562. images.append({"url": url})
  563. return images
  564. except Exception as e:
  565. error = e
  566. if r != None:
  567. data = r.json()
  568. if "error" in data:
  569. error = data["error"]["message"]
  570. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))