images.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
  1. import asyncio
  2. import base64
  3. import json
  4. import logging
  5. import mimetypes
  6. import re
  7. import uuid
  8. from pathlib import Path
  9. from typing import Optional
  10. import requests
  11. from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from pydantic import BaseModel
  14. from open_webui.config import CACHE_DIR
  15. from open_webui.constants import ERROR_MESSAGES
  16. from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
  17. from open_webui.utils.auth import get_admin_user, get_verified_user
  18. from open_webui.utils.images.comfyui import (
  19. ComfyUIGenerateImageForm,
  20. ComfyUIWorkflow,
  21. comfyui_generate_image,
  22. )
  23. log = logging.getLogger(__name__)
  24. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  25. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  26. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  27. router = APIRouter()
  28. @router.get("/config")
  29. async def get_config(request: Request, user=Depends(get_admin_user)):
  30. return {
  31. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  32. "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
  33. "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
  34. "openai": {
  35. "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  36. "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  37. },
  38. "automatic1111": {
  39. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  40. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  41. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  42. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  43. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  44. },
  45. "comfyui": {
  46. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  47. "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
  48. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  49. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  50. },
  51. }
class OpenAIConfigForm(BaseModel):
    """OpenAI-compatible image API settings submitted by an admin."""

    # Base URL; image_generations appends "/images/generations" to it.
    OPENAI_API_BASE_URL: str
    # Sent to the API as "Authorization: Bearer <key>".
    OPENAI_API_KEY: str
class Automatic1111ConfigForm(BaseModel):
    """Automatic1111 API settings submitted by an admin."""

    AUTOMATIC1111_BASE_URL: str
    # Encoded into an HTTP Basic credential by get_automatic1111_api_auth.
    AUTOMATIC1111_API_AUTH: str
    # Accepts str/float/int; update_config converts truthy values to float.
    # NOTE: Optional[...] without a default is required-but-nullable in
    # pydantic v2, so clients must still send the key (possibly as null).
    AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
    # Falsy values are stored as None by update_config.
    AUTOMATIC1111_SAMPLER: Optional[str]
    # Falsy values are stored as None by update_config.
    AUTOMATIC1111_SCHEDULER: Optional[str]
class ComfyUIConfigForm(BaseModel):
    """ComfyUI settings submitted by an admin."""

    # Leading/trailing "/" are stripped by update_config before storage.
    COMFYUI_BASE_URL: str
    # Sent to ComfyUI as "Authorization: Bearer <key>" when set.
    COMFYUI_API_KEY: str
    # Workflow as a JSON string (parsed with json.loads in get_models).
    COMFYUI_WORKFLOW: str
    # Node mapping; get_models reads each entry's "type" and "node_ids" keys.
    COMFYUI_WORKFLOW_NODES: list[dict]
class ConfigForm(BaseModel):
    """Complete admin settings payload for the ``/config/update`` route.

    Mirrors the structure returned by the ``/config`` route.
    """

    # Stored as ENABLE_IMAGE_GENERATION.
    enabled: bool
    # Stored as IMAGE_GENERATION_ENGINE; this module handles "openai",
    # "comfyui" and "automatic1111" (an empty string also selects automatic1111).
    engine: str
    # Stored as ENABLE_IMAGE_PROMPT_GENERATION.
    prompt_generation: bool
    openai: OpenAIConfigForm
    automatic1111: Automatic1111ConfigForm
    comfyui: ComfyUIConfigForm
  73. @router.post("/config/update")
  74. async def update_config(
  75. request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
  76. ):
  77. request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
  78. request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
  79. request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
  80. form_data.prompt_generation
  81. )
  82. request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
  83. form_data.openai.OPENAI_API_BASE_URL
  84. )
  85. request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
  86. request.app.state.config.AUTOMATIC1111_BASE_URL = (
  87. form_data.automatic1111.AUTOMATIC1111_BASE_URL
  88. )
  89. request.app.state.config.AUTOMATIC1111_API_AUTH = (
  90. form_data.automatic1111.AUTOMATIC1111_API_AUTH
  91. )
  92. request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
  93. float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
  94. if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
  95. else None
  96. )
  97. request.app.state.config.AUTOMATIC1111_SAMPLER = (
  98. form_data.automatic1111.AUTOMATIC1111_SAMPLER
  99. if form_data.automatic1111.AUTOMATIC1111_SAMPLER
  100. else None
  101. )
  102. request.app.state.config.AUTOMATIC1111_SCHEDULER = (
  103. form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  104. if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  105. else None
  106. )
  107. request.app.state.config.COMFYUI_BASE_URL = (
  108. form_data.comfyui.COMFYUI_BASE_URL.strip("/")
  109. )
  110. request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
  111. request.app.state.config.COMFYUI_WORKFLOW_NODES = (
  112. form_data.comfyui.COMFYUI_WORKFLOW_NODES
  113. )
  114. return {
  115. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  116. "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
  117. "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
  118. "openai": {
  119. "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  120. "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  121. },
  122. "automatic1111": {
  123. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  124. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  125. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  126. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  127. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  128. },
  129. "comfyui": {
  130. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  131. "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
  132. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  133. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  134. },
  135. }
  136. def get_automatic1111_api_auth(request: Request):
  137. if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
  138. return ""
  139. else:
  140. auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
  141. "utf-8"
  142. )
  143. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  144. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  145. return f"Basic {auth1111_base64_encoded_string}"
  146. @router.get("/config/url/verify")
  147. async def verify_url(request: Request, user=Depends(get_admin_user)):
  148. if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
  149. try:
  150. r = requests.get(
  151. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  152. headers={"authorization": get_automatic1111_api_auth(request)},
  153. )
  154. r.raise_for_status()
  155. return True
  156. except Exception:
  157. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  158. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  159. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  160. try:
  161. r = requests.get(
  162. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
  163. )
  164. r.raise_for_status()
  165. return True
  166. except Exception:
  167. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  168. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  169. else:
  170. return True
  171. def set_image_model(request: Request, model: str):
  172. log.info(f"Setting image model to {model}")
  173. request.app.state.config.IMAGE_GENERATION_MODEL = model
  174. if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
  175. api_auth = get_automatic1111_api_auth(request)
  176. r = requests.get(
  177. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  178. headers={"authorization": api_auth},
  179. )
  180. options = r.json()
  181. if model != options["sd_model_checkpoint"]:
  182. options["sd_model_checkpoint"] = model
  183. r = requests.post(
  184. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  185. json=options,
  186. headers={"authorization": api_auth},
  187. )
  188. return request.app.state.config.IMAGE_GENERATION_MODEL
  189. def get_image_model(request):
  190. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  191. return (
  192. request.app.state.config.IMAGE_GENERATION_MODEL
  193. if request.app.state.config.IMAGE_GENERATION_MODEL
  194. else "dall-e-2"
  195. )
  196. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  197. return (
  198. request.app.state.config.IMAGE_GENERATION_MODEL
  199. if request.app.state.config.IMAGE_GENERATION_MODEL
  200. else ""
  201. )
  202. elif (
  203. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  204. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  205. ):
  206. try:
  207. r = requests.get(
  208. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  209. headers={"authorization": get_automatic1111_api_auth(request)},
  210. )
  211. options = r.json()
  212. return options["sd_model_checkpoint"]
  213. except Exception as e:
  214. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  215. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class ImageConfigForm(BaseModel):
    """Payload for the ``/image/config/update`` route."""

    # Passed to set_image_model.
    MODEL: str
    # Validated by update_image_config as "<width>x<height>", e.g. "512x512".
    IMAGE_SIZE: str
    # Validated by update_image_config to be >= 0.
    IMAGE_STEPS: int
  220. @router.get("/image/config")
  221. async def get_image_config(request: Request, user=Depends(get_admin_user)):
  222. return {
  223. "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
  224. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  225. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  226. }
  227. @router.post("/image/config/update")
  228. async def update_image_config(
  229. request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
  230. ):
  231. set_image_model(request, form_data.MODEL)
  232. pattern = r"^\d+x\d+$"
  233. if re.match(pattern, form_data.IMAGE_SIZE):
  234. request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
  235. else:
  236. raise HTTPException(
  237. status_code=400,
  238. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  239. )
  240. if form_data.IMAGE_STEPS >= 0:
  241. request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
  242. else:
  243. raise HTTPException(
  244. status_code=400,
  245. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  246. )
  247. return {
  248. "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
  249. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  250. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  251. }
  252. @router.get("/models")
  253. def get_models(request: Request, user=Depends(get_verified_user)):
  254. try:
  255. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  256. return [
  257. {"id": "dall-e-2", "name": "DALL·E 2"},
  258. {"id": "dall-e-3", "name": "DALL·E 3"},
  259. ]
  260. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  261. # TODO - get models from comfyui
  262. headers = {
  263. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  264. }
  265. r = requests.get(
  266. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
  267. headers=headers,
  268. )
  269. info = r.json()
  270. workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
  271. model_node_id = None
  272. for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
  273. if node["type"] == "model":
  274. if node["node_ids"]:
  275. model_node_id = node["node_ids"][0]
  276. break
  277. if model_node_id:
  278. model_list_key = None
  279. print(workflow[model_node_id]["class_type"])
  280. for key in info[workflow[model_node_id]["class_type"]]["input"][
  281. "required"
  282. ]:
  283. if "_name" in key:
  284. model_list_key = key
  285. break
  286. if model_list_key:
  287. return list(
  288. map(
  289. lambda model: {"id": model, "name": model},
  290. info[workflow[model_node_id]["class_type"]]["input"][
  291. "required"
  292. ][model_list_key][0],
  293. )
  294. )
  295. else:
  296. return list(
  297. map(
  298. lambda model: {"id": model, "name": model},
  299. info["CheckpointLoaderSimple"]["input"]["required"][
  300. "ckpt_name"
  301. ][0],
  302. )
  303. )
  304. elif (
  305. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  306. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  307. ):
  308. r = requests.get(
  309. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  310. headers={"authorization": get_automatic1111_api_auth(request)},
  311. )
  312. models = r.json()
  313. return list(
  314. map(
  315. lambda model: {"id": model["title"], "name": model["model_name"]},
  316. models,
  317. )
  318. )
  319. except Exception as e:
  320. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  321. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class GenerateImageForm(BaseModel):
    """Request body for the ``/generations`` route."""

    # Only acted on by the Automatic1111 engine (switches the checkpoint).
    model: Optional[str] = None
    prompt: str
    # Only used by the OpenAI engine; falls back to the configured IMAGE_SIZE.
    size: Optional[str] = None
    # Number of images requested (sent as "n" / "batch_size").
    n: int = 1
    # Forwarded to the ComfyUI and Automatic1111 engines only.
    negative_prompt: Optional[str] = None
  328. def save_b64_image(b64_str):
  329. try:
  330. image_id = str(uuid.uuid4())
  331. if "," in b64_str:
  332. header, encoded = b64_str.split(",", 1)
  333. mime_type = header.split(";")[0]
  334. img_data = base64.b64decode(encoded)
  335. image_format = mimetypes.guess_extension(mime_type)
  336. image_filename = f"{image_id}{image_format}"
  337. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  338. with open(file_path, "wb") as f:
  339. f.write(img_data)
  340. return image_filename
  341. else:
  342. image_filename = f"{image_id}.png"
  343. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  344. img_data = base64.b64decode(b64_str)
  345. # Write the image data to a file
  346. with open(file_path, "wb") as f:
  347. f.write(img_data)
  348. return image_filename
  349. except Exception as e:
  350. log.exception(f"Error saving image: {e}")
  351. return None
  352. def save_url_image(url, headers=None):
  353. image_id = str(uuid.uuid4())
  354. try:
  355. if headers:
  356. r = requests.get(url, headers=headers)
  357. else:
  358. r = requests.get(url)
  359. r.raise_for_status()
  360. if r.headers["content-type"].split("/")[0] == "image":
  361. mime_type = r.headers["content-type"]
  362. image_format = mimetypes.guess_extension(mime_type)
  363. if not image_format:
  364. raise ValueError("Could not determine image type from MIME type")
  365. image_filename = f"{image_id}{image_format}"
  366. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  367. with open(file_path, "wb") as image_file:
  368. for chunk in r.iter_content(chunk_size=8192):
  369. image_file.write(chunk)
  370. return image_filename
  371. else:
  372. log.error("Url does not point to an image.")
  373. return None
  374. except Exception as e:
  375. log.exception(f"Error saving image: {e}")
  376. return None
  377. @router.post("/generations")
  378. async def image_generations(
  379. request: Request,
  380. form_data: GenerateImageForm,
  381. user=Depends(get_verified_user),
  382. ):
  383. width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
  384. r = None
  385. try:
  386. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  387. headers = {}
  388. headers["Authorization"] = (
  389. f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
  390. )
  391. headers["Content-Type"] = "application/json"
  392. if ENABLE_FORWARD_USER_INFO_HEADERS:
  393. headers["X-OpenWebUI-User-Name"] = user.name
  394. headers["X-OpenWebUI-User-Id"] = user.id
  395. headers["X-OpenWebUI-User-Email"] = user.email
  396. headers["X-OpenWebUI-User-Role"] = user.role
  397. data = {
  398. "model": (
  399. request.app.state.config.IMAGE_GENERATION_MODEL
  400. if request.app.state.config.IMAGE_GENERATION_MODEL != ""
  401. else "dall-e-2"
  402. ),
  403. "prompt": form_data.prompt,
  404. "n": form_data.n,
  405. "size": (
  406. form_data.size
  407. if form_data.size
  408. else request.app.state.config.IMAGE_SIZE
  409. ),
  410. "response_format": "b64_json",
  411. }
  412. # Use asyncio.to_thread for the requests.post call
  413. r = await asyncio.to_thread(
  414. requests.post,
  415. url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
  416. json=data,
  417. headers=headers,
  418. )
  419. r.raise_for_status()
  420. res = r.json()
  421. images = []
  422. for image in res["data"]:
  423. image_filename = save_b64_image(image["b64_json"])
  424. images.append({"url": f"/cache/image/generations/{image_filename}"})
  425. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  426. with open(file_body_path, "w") as f:
  427. json.dump(data, f)
  428. return images
  429. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  430. data = {
  431. "prompt": form_data.prompt,
  432. "width": width,
  433. "height": height,
  434. "n": form_data.n,
  435. }
  436. if request.app.state.config.IMAGE_STEPS is not None:
  437. data["steps"] = request.app.state.config.IMAGE_STEPS
  438. if form_data.negative_prompt is not None:
  439. data["negative_prompt"] = form_data.negative_prompt
  440. form_data = ComfyUIGenerateImageForm(
  441. **{
  442. "workflow": ComfyUIWorkflow(
  443. **{
  444. "workflow": request.app.state.config.COMFYUI_WORKFLOW,
  445. "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  446. }
  447. ),
  448. **data,
  449. }
  450. )
  451. res = await comfyui_generate_image(
  452. request.app.state.config.IMAGE_GENERATION_MODEL,
  453. form_data,
  454. user.id,
  455. request.app.state.config.COMFYUI_BASE_URL,
  456. request.app.state.config.COMFYUI_API_KEY,
  457. )
  458. log.debug(f"res: {res}")
  459. images = []
  460. for image in res["data"]:
  461. headers = None
  462. if request.app.state.config.COMFYUI_API_KEY:
  463. headers = {
  464. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  465. }
  466. image_filename = save_url_image(image["url"], headers)
  467. images.append({"url": f"/cache/image/generations/{image_filename}"})
  468. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  469. with open(file_body_path, "w") as f:
  470. json.dump(form_data.model_dump(exclude_none=True), f)
  471. log.debug(f"images: {images}")
  472. return images
  473. elif (
  474. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  475. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  476. ):
  477. if form_data.model:
  478. set_image_model(form_data.model)
  479. data = {
  480. "prompt": form_data.prompt,
  481. "batch_size": form_data.n,
  482. "width": width,
  483. "height": height,
  484. }
  485. if request.app.state.config.IMAGE_STEPS is not None:
  486. data["steps"] = request.app.state.config.IMAGE_STEPS
  487. if form_data.negative_prompt is not None:
  488. data["negative_prompt"] = form_data.negative_prompt
  489. if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
  490. data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
  491. if request.app.state.config.AUTOMATIC1111_SAMPLER:
  492. data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
  493. if request.app.state.config.AUTOMATIC1111_SCHEDULER:
  494. data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
  495. # Use asyncio.to_thread for the requests.post call
  496. r = await asyncio.to_thread(
  497. requests.post,
  498. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  499. json=data,
  500. headers={"authorization": get_automatic1111_api_auth(request)},
  501. )
  502. res = r.json()
  503. log.debug(f"res: {res}")
  504. images = []
  505. for image in res["images"]:
  506. image_filename = save_b64_image(image)
  507. images.append({"url": f"/cache/image/generations/{image_filename}"})
  508. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  509. with open(file_body_path, "w") as f:
  510. json.dump({**data, "info": res["info"]}, f)
  511. return images
  512. except Exception as e:
  513. error = e
  514. if r != None:
  515. data = r.json()
  516. if "error" in data:
  517. error = data["error"]["message"]
  518. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))