images.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678
  1. import asyncio
  2. import base64
  3. import io
  4. import json
  5. import logging
  6. import mimetypes
  7. import re
  8. from pathlib import Path
  9. from typing import Optional
  10. import requests
  11. from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
  12. from open_webui.config import CACHE_DIR
  13. from open_webui.constants import ERROR_MESSAGES
  14. from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
  15. from open_webui.routers.files import upload_file
  16. from open_webui.utils.auth import get_admin_user, get_verified_user
  17. from open_webui.utils.images.comfyui import (
  18. ComfyUIGenerateImageForm,
  19. ComfyUIWorkflow,
  20. comfyui_generate_image,
  21. )
  22. from pydantic import BaseModel
  23. log = logging.getLogger(__name__)
  24. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  25. IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
  26. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  27. router = APIRouter()
  28. @router.get("/config")
  29. async def get_config(request: Request, user=Depends(get_admin_user)):
  30. return {
  31. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  32. "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
  33. "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
  34. "openai": {
  35. "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  36. "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  37. },
  38. "automatic1111": {
  39. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  40. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  41. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  42. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  43. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  44. },
  45. "comfyui": {
  46. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  47. "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
  48. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  49. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  50. },
  51. "gemini": {
  52. "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
  53. "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
  54. },
  55. }
class OpenAIConfigForm(BaseModel):
    """OpenAI-compatible image API settings submitted to POST /config/update."""

    # Base URL; "/images/generations" is appended when generating.
    OPENAI_API_BASE_URL: str
    # Sent as "Authorization: Bearer <key>".
    OPENAI_API_KEY: str
  59. class Automatic1111ConfigForm(BaseModel):
  60. AUTOMATIC1111_BASE_URL: str
  61. AUTOMATIC1111_API_AUTH: str
  62. AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
  63. AUTOMATIC1111_SAMPLER: Optional[str]
  64. AUTOMATIC1111_SCHEDULER: Optional[str]
class ComfyUIConfigForm(BaseModel):
    """ComfyUI settings submitted as part of POST /config/update."""

    # update_config strips "/" from both ends before storing it.
    COMFYUI_BASE_URL: str
    # Used as a Bearer token for requests to the ComfyUI server.
    COMFYUI_API_KEY: str
    # Workflow as a JSON string (get_models parses it with json.loads).
    COMFYUI_WORKFLOW: str
    # Node descriptors; get_models reads their "type" and "node_ids" keys.
    COMFYUI_WORKFLOW_NODES: list[dict]
class GeminiConfigForm(BaseModel):
    """Gemini image API settings submitted as part of POST /config/update."""

    # Base URL; "/models/<model>:predict" is appended when generating.
    GEMINI_API_BASE_URL: str
    # Sent in the "x-goog-api-key" header.
    GEMINI_API_KEY: str
class ConfigForm(BaseModel):
    """Complete image-generation configuration accepted by POST /config/update."""

    enabled: bool  # stored as ENABLE_IMAGE_GENERATION
    engine: str  # "openai", "gemini", "comfyui", "automatic1111" (or "" for A1111)
    prompt_generation: bool  # stored as ENABLE_IMAGE_PROMPT_GENERATION
    openai: OpenAIConfigForm
    automatic1111: Automatic1111ConfigForm
    comfyui: ComfyUIConfigForm
    gemini: GeminiConfigForm
  81. @router.post("/config/update")
  82. async def update_config(
  83. request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
  84. ):
  85. request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
  86. request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
  87. request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
  88. form_data.prompt_generation
  89. )
  90. request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
  91. form_data.openai.OPENAI_API_BASE_URL
  92. )
  93. request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
  94. request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
  95. form_data.gemini.GEMINI_API_BASE_URL
  96. )
  97. request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
  98. request.app.state.config.AUTOMATIC1111_BASE_URL = (
  99. form_data.automatic1111.AUTOMATIC1111_BASE_URL
  100. )
  101. request.app.state.config.AUTOMATIC1111_API_AUTH = (
  102. form_data.automatic1111.AUTOMATIC1111_API_AUTH
  103. )
  104. request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
  105. float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
  106. if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
  107. else None
  108. )
  109. request.app.state.config.AUTOMATIC1111_SAMPLER = (
  110. form_data.automatic1111.AUTOMATIC1111_SAMPLER
  111. if form_data.automatic1111.AUTOMATIC1111_SAMPLER
  112. else None
  113. )
  114. request.app.state.config.AUTOMATIC1111_SCHEDULER = (
  115. form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  116. if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  117. else None
  118. )
  119. request.app.state.config.COMFYUI_BASE_URL = (
  120. form_data.comfyui.COMFYUI_BASE_URL.strip("/")
  121. )
  122. request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
  123. request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
  124. request.app.state.config.COMFYUI_WORKFLOW_NODES = (
  125. form_data.comfyui.COMFYUI_WORKFLOW_NODES
  126. )
  127. return {
  128. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  129. "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
  130. "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
  131. "openai": {
  132. "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  133. "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  134. },
  135. "automatic1111": {
  136. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  137. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  138. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  139. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  140. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  141. },
  142. "comfyui": {
  143. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  144. "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
  145. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  146. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  147. },
  148. "gemini": {
  149. "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
  150. "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
  151. },
  152. }
  153. def get_automatic1111_api_auth(request: Request):
  154. if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
  155. return ""
  156. else:
  157. auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
  158. "utf-8"
  159. )
  160. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  161. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  162. return f"Basic {auth1111_base64_encoded_string}"
  163. @router.get("/config/url/verify")
  164. async def verify_url(request: Request, user=Depends(get_admin_user)):
  165. if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
  166. try:
  167. r = requests.get(
  168. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  169. headers={"authorization": get_automatic1111_api_auth(request)},
  170. )
  171. r.raise_for_status()
  172. return True
  173. except Exception:
  174. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  175. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  176. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  177. headers = None
  178. if request.app.state.config.COMFYUI_API_KEY:
  179. headers = {
  180. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  181. }
  182. try:
  183. r = requests.get(
  184. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
  185. headers=headers,
  186. )
  187. r.raise_for_status()
  188. return True
  189. except Exception:
  190. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  191. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  192. else:
  193. return True
  194. def set_image_model(request: Request, model: str):
  195. log.info(f"Setting image model to {model}")
  196. request.app.state.config.IMAGE_GENERATION_MODEL = model
  197. if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
  198. api_auth = get_automatic1111_api_auth(request)
  199. r = requests.get(
  200. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  201. headers={"authorization": api_auth},
  202. )
  203. options = r.json()
  204. if model != options["sd_model_checkpoint"]:
  205. options["sd_model_checkpoint"] = model
  206. r = requests.post(
  207. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  208. json=options,
  209. headers={"authorization": api_auth},
  210. )
  211. return request.app.state.config.IMAGE_GENERATION_MODEL
  212. def get_image_model(request):
  213. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  214. return (
  215. request.app.state.config.IMAGE_GENERATION_MODEL
  216. if request.app.state.config.IMAGE_GENERATION_MODEL
  217. else "dall-e-2"
  218. )
  219. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  220. return (
  221. request.app.state.config.IMAGE_GENERATION_MODEL
  222. if request.app.state.config.IMAGE_GENERATION_MODEL
  223. else "imagen-3.0-generate-002"
  224. )
  225. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  226. return (
  227. request.app.state.config.IMAGE_GENERATION_MODEL
  228. if request.app.state.config.IMAGE_GENERATION_MODEL
  229. else ""
  230. )
  231. elif (
  232. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  233. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  234. ):
  235. try:
  236. r = requests.get(
  237. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  238. headers={"authorization": get_automatic1111_api_auth(request)},
  239. )
  240. options = r.json()
  241. return options["sd_model_checkpoint"]
  242. except Exception as e:
  243. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  244. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class ImageConfigForm(BaseModel):
    """Request body for POST /image/config/update."""

    MODEL: str  # passed to set_image_model
    IMAGE_SIZE: str  # must look like "<width>x<height>", e.g. "512x512"
    IMAGE_STEPS: int  # must be >= 0
  249. @router.get("/image/config")
  250. async def get_image_config(request: Request, user=Depends(get_admin_user)):
  251. return {
  252. "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
  253. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  254. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  255. }
  256. @router.post("/image/config/update")
  257. async def update_image_config(
  258. request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
  259. ):
  260. set_image_model(request, form_data.MODEL)
  261. pattern = r"^\d+x\d+$"
  262. if re.match(pattern, form_data.IMAGE_SIZE):
  263. request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
  264. else:
  265. raise HTTPException(
  266. status_code=400,
  267. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  268. )
  269. if form_data.IMAGE_STEPS >= 0:
  270. request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
  271. else:
  272. raise HTTPException(
  273. status_code=400,
  274. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  275. )
  276. return {
  277. "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
  278. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  279. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  280. }
  281. @router.get("/models")
  282. def get_models(request: Request, user=Depends(get_verified_user)):
  283. try:
  284. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  285. return [
  286. {"id": "dall-e-2", "name": "DALL·E 2"},
  287. {"id": "dall-e-3", "name": "DALL·E 3"},
  288. ]
  289. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  290. return [
  291. {"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
  292. ]
  293. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  294. # TODO - get models from comfyui
  295. headers = {
  296. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  297. }
  298. r = requests.get(
  299. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
  300. headers=headers,
  301. )
  302. info = r.json()
  303. workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
  304. model_node_id = None
  305. for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
  306. if node["type"] == "model":
  307. if node["node_ids"]:
  308. model_node_id = node["node_ids"][0]
  309. break
  310. if model_node_id:
  311. model_list_key = None
  312. log.info(workflow[model_node_id]["class_type"])
  313. for key in info[workflow[model_node_id]["class_type"]]["input"][
  314. "required"
  315. ]:
  316. if "_name" in key:
  317. model_list_key = key
  318. break
  319. if model_list_key:
  320. return list(
  321. map(
  322. lambda model: {"id": model, "name": model},
  323. info[workflow[model_node_id]["class_type"]]["input"][
  324. "required"
  325. ][model_list_key][0],
  326. )
  327. )
  328. else:
  329. return list(
  330. map(
  331. lambda model: {"id": model, "name": model},
  332. info["CheckpointLoaderSimple"]["input"]["required"][
  333. "ckpt_name"
  334. ][0],
  335. )
  336. )
  337. elif (
  338. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  339. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  340. ):
  341. r = requests.get(
  342. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  343. headers={"authorization": get_automatic1111_api_auth(request)},
  344. )
  345. models = r.json()
  346. return list(
  347. map(
  348. lambda model: {"id": model["title"], "name": model["model_name"]},
  349. models,
  350. )
  351. )
  352. except Exception as e:
  353. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  354. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class GenerateImageForm(BaseModel):
    """Request body for POST /generations."""

    # Only read by the AUTOMATIC1111 branch, which switches to it before generating.
    model: Optional[str] = None
    prompt: str
    # Only read by the OpenAI branch; overrides the configured IMAGE_SIZE.
    size: Optional[str] = None
    # Number of images to request from the backend.
    n: int = 1
    # Only forwarded to ComfyUI and AUTOMATIC1111.
    negative_prompt: Optional[str] = None
  361. def load_b64_image_data(b64_str):
  362. try:
  363. if "," in b64_str:
  364. header, encoded = b64_str.split(",", 1)
  365. mime_type = header.split(";")[0]
  366. img_data = base64.b64decode(encoded)
  367. else:
  368. mime_type = "image/png"
  369. img_data = base64.b64decode(b64_str)
  370. return img_data, mime_type
  371. except Exception as e:
  372. log.exception(f"Error loading image data: {e}")
  373. return None
  374. def load_url_image_data(url, headers=None):
  375. try:
  376. if headers:
  377. r = requests.get(url, headers=headers)
  378. else:
  379. r = requests.get(url)
  380. r.raise_for_status()
  381. if r.headers["content-type"].split("/")[0] == "image":
  382. mime_type = r.headers["content-type"]
  383. return r.content, mime_type
  384. else:
  385. log.error("Url does not point to an image.")
  386. return None
  387. except Exception as e:
  388. log.exception(f"Error saving image: {e}")
  389. return None
  390. def upload_image(request, image_metadata, image_data, content_type, user):
  391. image_format = mimetypes.guess_extension(content_type)
  392. file = UploadFile(
  393. file=io.BytesIO(image_data),
  394. filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
  395. headers={
  396. "content-type": content_type,
  397. },
  398. )
  399. file_item = upload_file(request, file, user, file_metadata=image_metadata)
  400. url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
  401. return url
  402. @router.post("/generations")
  403. async def image_generations(
  404. request: Request,
  405. form_data: GenerateImageForm,
  406. user=Depends(get_verified_user),
  407. ):
  408. width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
  409. r = None
  410. try:
  411. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  412. headers = {}
  413. headers["Authorization"] = (
  414. f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
  415. )
  416. headers["Content-Type"] = "application/json"
  417. if ENABLE_FORWARD_USER_INFO_HEADERS:
  418. headers["X-OpenWebUI-User-Name"] = user.name
  419. headers["X-OpenWebUI-User-Id"] = user.id
  420. headers["X-OpenWebUI-User-Email"] = user.email
  421. headers["X-OpenWebUI-User-Role"] = user.role
  422. data = {
  423. "model": (
  424. request.app.state.config.IMAGE_GENERATION_MODEL
  425. if request.app.state.config.IMAGE_GENERATION_MODEL != ""
  426. else "dall-e-2"
  427. ),
  428. "prompt": form_data.prompt,
  429. "n": form_data.n,
  430. "size": (
  431. form_data.size
  432. if form_data.size
  433. else request.app.state.config.IMAGE_SIZE
  434. ),
  435. "response_format": "b64_json",
  436. }
  437. # Use asyncio.to_thread for the requests.post call
  438. r = await asyncio.to_thread(
  439. requests.post,
  440. url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
  441. json=data,
  442. headers=headers,
  443. )
  444. r.raise_for_status()
  445. res = r.json()
  446. images = []
  447. for image in res["data"]:
  448. if "url" in image:
  449. image_data, content_type = load_url_image_data(
  450. image["url"], headers
  451. )
  452. else:
  453. image_data, content_type = load_b64_image_data(image["b64_json"])
  454. url = upload_image(request, data, image_data, content_type, user)
  455. images.append({"url": url})
  456. return images
  457. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  458. headers = {}
  459. headers["Content-Type"] = "application/json"
  460. headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
  461. model = get_image_model(request)
  462. data = {
  463. "instances": {"prompt": form_data.prompt},
  464. "parameters": {
  465. "sampleCount": form_data.n,
  466. "outputOptions": {"mimeType": "image/png"},
  467. },
  468. }
  469. # Use asyncio.to_thread for the requests.post call
  470. r = await asyncio.to_thread(
  471. requests.post,
  472. url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
  473. json=data,
  474. headers=headers,
  475. )
  476. r.raise_for_status()
  477. res = r.json()
  478. images = []
  479. for image in res["predictions"]:
  480. image_data, content_type = load_b64_image_data(
  481. image["bytesBase64Encoded"]
  482. )
  483. url = upload_image(request, data, image_data, content_type, user)
  484. images.append({"url": url})
  485. return images
  486. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  487. data = {
  488. "prompt": form_data.prompt,
  489. "width": width,
  490. "height": height,
  491. "n": form_data.n,
  492. }
  493. if request.app.state.config.IMAGE_STEPS is not None:
  494. data["steps"] = request.app.state.config.IMAGE_STEPS
  495. if form_data.negative_prompt is not None:
  496. data["negative_prompt"] = form_data.negative_prompt
  497. form_data = ComfyUIGenerateImageForm(
  498. **{
  499. "workflow": ComfyUIWorkflow(
  500. **{
  501. "workflow": request.app.state.config.COMFYUI_WORKFLOW,
  502. "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  503. }
  504. ),
  505. **data,
  506. }
  507. )
  508. res = await comfyui_generate_image(
  509. request.app.state.config.IMAGE_GENERATION_MODEL,
  510. form_data,
  511. user.id,
  512. request.app.state.config.COMFYUI_BASE_URL,
  513. request.app.state.config.COMFYUI_API_KEY,
  514. )
  515. log.debug(f"res: {res}")
  516. images = []
  517. for image in res["data"]:
  518. headers = None
  519. if request.app.state.config.COMFYUI_API_KEY:
  520. headers = {
  521. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  522. }
  523. image_data, content_type = load_url_image_data(image["url"], headers)
  524. url = upload_image(
  525. request,
  526. form_data.model_dump(exclude_none=True),
  527. image_data,
  528. content_type,
  529. user,
  530. )
  531. images.append({"url": url})
  532. return images
  533. elif (
  534. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  535. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  536. ):
  537. if form_data.model:
  538. set_image_model(form_data.model)
  539. data = {
  540. "prompt": form_data.prompt,
  541. "batch_size": form_data.n,
  542. "width": width,
  543. "height": height,
  544. }
  545. if request.app.state.config.IMAGE_STEPS is not None:
  546. data["steps"] = request.app.state.config.IMAGE_STEPS
  547. if form_data.negative_prompt is not None:
  548. data["negative_prompt"] = form_data.negative_prompt
  549. if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
  550. data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
  551. if request.app.state.config.AUTOMATIC1111_SAMPLER:
  552. data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
  553. if request.app.state.config.AUTOMATIC1111_SCHEDULER:
  554. data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
  555. # Use asyncio.to_thread for the requests.post call
  556. r = await asyncio.to_thread(
  557. requests.post,
  558. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  559. json=data,
  560. headers={"authorization": get_automatic1111_api_auth(request)},
  561. )
  562. res = r.json()
  563. log.debug(f"res: {res}")
  564. images = []
  565. for image in res["images"]:
  566. image_data, content_type = load_b64_image_data(image)
  567. url = upload_image(
  568. request,
  569. {**data, "info": res["info"]},
  570. image_data,
  571. content_type,
  572. user,
  573. )
  574. images.append({"url": url})
  575. return images
  576. except Exception as e:
  577. error = e
  578. if r != None:
  579. data = r.json()
  580. if "error" in data:
  581. error = data["error"]["message"]
  582. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))