main.py 18 KB


  1. import asyncio
  2. import base64
  3. import json
  4. import logging
  5. import mimetypes
  6. import re
  7. import uuid
  8. from pathlib import Path
  9. from typing import Optional
  10. import requests
  11. from apps.images.utils.comfyui import (
  12. ComfyUIGenerateImageForm,
  13. ComfyUIWorkflow,
  14. comfyui_generate_image,
  15. )
  16. from config import (
  17. AUTOMATIC1111_API_AUTH,
  18. AUTOMATIC1111_BASE_URL,
  19. CACHE_DIR,
  20. COMFYUI_BASE_URL,
  21. COMFYUI_WORKFLOW,
  22. COMFYUI_WORKFLOW_NODES,
  23. CORS_ALLOW_ORIGIN,
  24. ENABLE_IMAGE_GENERATION,
  25. IMAGE_GENERATION_ENGINE,
  26. IMAGE_GENERATION_MODEL,
  27. IMAGE_SIZE,
  28. IMAGE_STEPS,
  29. IMAGES_OPENAI_API_BASE_URL,
  30. IMAGES_OPENAI_API_KEY,
  31. AppConfig,
  32. )
  33. from constants import ERROR_MESSAGES
  34. from env import SRC_LOG_LEVELS
  35. from fastapi import Depends, FastAPI, HTTPException, Request
  36. from fastapi.middleware.cors import CORSMiddleware
  37. from pydantic import BaseModel
  38. from utils.utils import get_admin_user, get_verified_user
  39. log = logging.getLogger(__name__)
  40. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  41. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  42. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  43. app = FastAPI()
  44. app.add_middleware(
  45. CORSMiddleware,
  46. allow_origins=CORS_ALLOW_ORIGIN,
  47. allow_credentials=True,
  48. allow_methods=["*"],
  49. allow_headers=["*"],
  50. )
  51. app.state.config = AppConfig()
  52. app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
  53. app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
  54. app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  55. app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
  56. app.state.config.MODEL = IMAGE_GENERATION_MODEL
  57. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  58. app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
  59. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  60. app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
  61. app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
  62. app.state.config.IMAGE_SIZE = IMAGE_SIZE
  63. app.state.config.IMAGE_STEPS = IMAGE_STEPS
  64. @app.get("/config")
  65. async def get_config(request: Request, user=Depends(get_admin_user)):
  66. return {
  67. "enabled": app.state.config.ENABLED,
  68. "engine": app.state.config.ENGINE,
  69. "openai": {
  70. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  71. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  72. },
  73. "automatic1111": {
  74. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  75. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  76. },
  77. "comfyui": {
  78. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  79. "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
  80. "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
  81. },
  82. }
# "openai" section of the /config/update request body.
class OpenAIConfigForm(BaseModel):
    # Base URL; "/images/generations" is appended to it when generating.
    OPENAI_API_BASE_URL: str
    # Sent as "Authorization: Bearer <key>" on generation requests.
    OPENAI_API_KEY: str
# "automatic1111" section of the /config/update request body.
class Automatic1111ConfigForm(BaseModel):
    # Base URL; "/sdapi/v1/..." endpoint paths are appended to it.
    AUTOMATIC1111_BASE_URL: str
    # Credential string that get_automatic1111_api_auth base64-encodes into
    # an HTTP Basic "authorization" header (presumably "user:password").
    AUTOMATIC1111_API_AUTH: str
# "comfyui" section of the /config/update request body.
class ComfyUIConfigForm(BaseModel):
    # Base URL of the ComfyUI server (e.g. "/object_info" is appended to it).
    COMFYUI_BASE_URL: str
    # Workflow as JSON text; get_models parses it with json.loads.
    COMFYUI_WORKFLOW: str
    # Node descriptors; get_models reads their "type" and "node_ids" keys.
    COMFYUI_WORKFLOW_NODES: list[dict]
# Full request body accepted by POST /config/update.
class ConfigForm(BaseModel):
    # Global on/off switch for image generation.
    enabled: bool
    # Backend selector; the endpoints in this module handle "openai",
    # "comfyui", and "automatic1111" ("" is treated like "automatic1111").
    engine: str
    openai: OpenAIConfigForm
    automatic1111: Automatic1111ConfigForm
    comfyui: ComfyUIConfigForm
  99. @app.post("/config/update")
  100. async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
  101. app.state.config.ENGINE = form_data.engine
  102. app.state.config.ENABLED = form_data.enabled
  103. app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
  104. app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
  105. app.state.config.AUTOMATIC1111_BASE_URL = (
  106. form_data.automatic1111.AUTOMATIC1111_BASE_URL
  107. )
  108. app.state.config.AUTOMATIC1111_API_AUTH = (
  109. form_data.automatic1111.AUTOMATIC1111_API_AUTH
  110. )
  111. app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL
  112. app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
  113. app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
  114. return {
  115. "enabled": app.state.config.ENABLED,
  116. "engine": app.state.config.ENGINE,
  117. "openai": {
  118. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  119. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  120. },
  121. "automatic1111": {
  122. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  123. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  124. },
  125. "comfyui": {
  126. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  127. "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
  128. "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
  129. },
  130. }
  131. def get_automatic1111_api_auth():
  132. if app.state.config.AUTOMATIC1111_API_AUTH is None:
  133. return ""
  134. else:
  135. auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
  136. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  137. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  138. return f"Basic {auth1111_base64_encoded_string}"
  139. @app.get("/config/url/verify")
  140. async def verify_url(user=Depends(get_admin_user)):
  141. if app.state.config.ENGINE == "automatic1111":
  142. try:
  143. r = requests.get(
  144. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  145. headers={"authorization": get_automatic1111_api_auth()},
  146. )
  147. r.raise_for_status()
  148. return True
  149. except Exception:
  150. app.state.config.ENABLED = False
  151. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  152. elif app.state.config.ENGINE == "comfyui":
  153. try:
  154. r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
  155. r.raise_for_status()
  156. return True
  157. except Exception:
  158. app.state.config.ENABLED = False
  159. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  160. else:
  161. return True
  162. def set_image_model(model: str):
  163. app.state.config.MODEL = model
  164. if app.state.config.ENGINE in ["", "automatic1111"]:
  165. api_auth = get_automatic1111_api_auth()
  166. r = requests.get(
  167. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  168. headers={"authorization": api_auth},
  169. )
  170. options = r.json()
  171. if model != options["sd_model_checkpoint"]:
  172. options["sd_model_checkpoint"] = model
  173. r = requests.post(
  174. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  175. json=options,
  176. headers={"authorization": api_auth},
  177. )
  178. return app.state.config.MODEL
  179. def get_image_model():
  180. if app.state.config.ENGINE == "openai":
  181. return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
  182. elif app.state.config.ENGINE == "comfyui":
  183. return app.state.config.MODEL if app.state.config.MODEL else ""
  184. elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
  185. try:
  186. r = requests.get(
  187. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  188. headers={"authorization": get_automatic1111_api_auth()},
  189. )
  190. options = r.json()
  191. return options["sd_model_checkpoint"]
  192. except Exception as e:
  193. app.state.config.ENABLED = False
  194. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
# Request body accepted by POST /image/config/update.
class ImageConfigForm(BaseModel):
    # Model identifier stored as app.state.config.MODEL.
    MODEL: str
    # "<width>x<height>", e.g. "512x512"; the format is checked by the endpoint.
    IMAGE_SIZE: str
    # Sampling steps; the endpoint rejects negative values.
    IMAGE_STEPS: int
  199. @app.get("/image/config")
  200. async def get_image_config(user=Depends(get_admin_user)):
  201. return {
  202. "MODEL": app.state.config.MODEL,
  203. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  204. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  205. }
  206. @app.post("/image/config/update")
  207. async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
  208. app.state.config.MODEL = form_data.MODEL
  209. pattern = r"^\d+x\d+$"
  210. if re.match(pattern, form_data.IMAGE_SIZE):
  211. app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
  212. else:
  213. raise HTTPException(
  214. status_code=400,
  215. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  216. )
  217. if form_data.IMAGE_STEPS >= 0:
  218. app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
  219. else:
  220. raise HTTPException(
  221. status_code=400,
  222. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  223. )
  224. return {
  225. "MODEL": app.state.config.MODEL,
  226. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  227. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  228. }
  229. @app.get("/models")
  230. def get_models(user=Depends(get_verified_user)):
  231. try:
  232. if app.state.config.ENGINE == "openai":
  233. return [
  234. {"id": "dall-e-2", "name": "DALL·E 2"},
  235. {"id": "dall-e-3", "name": "DALL·E 3"},
  236. ]
  237. elif app.state.config.ENGINE == "comfyui":
  238. # TODO - get models from comfyui
  239. r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
  240. info = r.json()
  241. workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
  242. model_node_id = None
  243. for node in app.state.config.COMFYUI_WORKFLOW_NODES:
  244. if node["type"] == "model":
  245. if node["node_ids"]:
  246. model_node_id = node["node_ids"][0]
  247. break
  248. if model_node_id:
  249. model_list_key = None
  250. print(workflow[model_node_id]["class_type"])
  251. for key in info[workflow[model_node_id]["class_type"]]["input"][
  252. "required"
  253. ]:
  254. if "_name" in key:
  255. model_list_key = key
  256. break
  257. if model_list_key:
  258. return list(
  259. map(
  260. lambda model: {"id": model, "name": model},
  261. info[workflow[model_node_id]["class_type"]]["input"][
  262. "required"
  263. ][model_list_key][0],
  264. )
  265. )
  266. else:
  267. return list(
  268. map(
  269. lambda model: {"id": model, "name": model},
  270. info["CheckpointLoaderSimple"]["input"]["required"][
  271. "ckpt_name"
  272. ][0],
  273. )
  274. )
  275. elif (
  276. app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
  277. ):
  278. r = requests.get(
  279. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  280. headers={"authorization": get_automatic1111_api_auth()},
  281. )
  282. models = r.json()
  283. return list(
  284. map(
  285. lambda model: {"id": model["title"], "name": model["model_name"]},
  286. models,
  287. )
  288. )
  289. except Exception as e:
  290. app.state.config.ENABLED = False
  291. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
# Request body accepted by POST /generations.
class GenerateImageForm(BaseModel):
    # Only the AUTOMATIC1111 path applies this (via set_image_model).
    model: Optional[str] = None
    prompt: str
    # "<width>x<height>"; only the OpenAI path reads it (other engines use
    # the configured IMAGE_SIZE).
    size: Optional[str] = None
    # Number of images to request.
    n: int = 1
    # Forwarded by the ComfyUI and AUTOMATIC1111 paths when provided.
    negative_prompt: Optional[str] = None
  298. def save_b64_image(b64_str):
  299. try:
  300. image_id = str(uuid.uuid4())
  301. if "," in b64_str:
  302. header, encoded = b64_str.split(",", 1)
  303. mime_type = header.split(";")[0]
  304. img_data = base64.b64decode(encoded)
  305. image_format = mimetypes.guess_extension(mime_type)
  306. image_filename = f"{image_id}{image_format}"
  307. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  308. with open(file_path, "wb") as f:
  309. f.write(img_data)
  310. return image_filename
  311. else:
  312. image_filename = f"{image_id}.png"
  313. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  314. img_data = base64.b64decode(b64_str)
  315. # Write the image data to a file
  316. with open(file_path, "wb") as f:
  317. f.write(img_data)
  318. return image_filename
  319. except Exception as e:
  320. log.exception(f"Error saving image: {e}")
  321. return None
  322. def save_url_image(url):
  323. image_id = str(uuid.uuid4())
  324. try:
  325. r = requests.get(url)
  326. r.raise_for_status()
  327. if r.headers["content-type"].split("/")[0] == "image":
  328. mime_type = r.headers["content-type"]
  329. image_format = mimetypes.guess_extension(mime_type)
  330. if not image_format:
  331. raise ValueError("Could not determine image type from MIME type")
  332. image_filename = f"{image_id}{image_format}"
  333. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  334. with open(file_path, "wb") as image_file:
  335. for chunk in r.iter_content(chunk_size=8192):
  336. image_file.write(chunk)
  337. return image_filename
  338. else:
  339. log.error("Url does not point to an image.")
  340. return None
  341. except Exception as e:
  342. log.exception(f"Error saving image: {e}")
  343. return None
  344. @app.post("/generations")
  345. async def image_generations(
  346. form_data: GenerateImageForm,
  347. user=Depends(get_verified_user),
  348. ):
  349. width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
  350. r = None
  351. try:
  352. if app.state.config.ENGINE == "openai":
  353. headers = {}
  354. headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
  355. headers["Content-Type"] = "application/json"
  356. data = {
  357. "model": (
  358. app.state.config.MODEL
  359. if app.state.config.MODEL != ""
  360. else "dall-e-2"
  361. ),
  362. "prompt": form_data.prompt,
  363. "n": form_data.n,
  364. "size": (
  365. form_data.size if form_data.size else app.state.config.IMAGE_SIZE
  366. ),
  367. "response_format": "b64_json",
  368. }
  369. r = requests.post(
  370. url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
  371. json=data,
  372. headers=headers,
  373. )
  374. r.raise_for_status()
  375. res = r.json()
  376. images = []
  377. for image in res["data"]:
  378. image_filename = save_b64_image(image["b64_json"])
  379. images.append({"url": f"/cache/image/generations/{image_filename}"})
  380. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  381. with open(file_body_path, "w") as f:
  382. json.dump(data, f)
  383. return images
  384. elif app.state.config.ENGINE == "comfyui":
  385. data = {
  386. "prompt": form_data.prompt,
  387. "width": width,
  388. "height": height,
  389. "n": form_data.n,
  390. }
  391. if app.state.config.IMAGE_STEPS is not None:
  392. data["steps"] = app.state.config.IMAGE_STEPS
  393. if form_data.negative_prompt is not None:
  394. data["negative_prompt"] = form_data.negative_prompt
  395. form_data = ComfyUIGenerateImageForm(
  396. **{
  397. "workflow": ComfyUIWorkflow(
  398. **{
  399. "workflow": app.state.config.COMFYUI_WORKFLOW,
  400. "nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
  401. }
  402. ),
  403. **data,
  404. }
  405. )
  406. res = await comfyui_generate_image(
  407. app.state.config.MODEL,
  408. form_data,
  409. user.id,
  410. app.state.config.COMFYUI_BASE_URL,
  411. )
  412. log.debug(f"res: {res}")
  413. images = []
  414. for image in res["data"]:
  415. image_filename = save_url_image(image["url"])
  416. images.append({"url": f"/cache/image/generations/{image_filename}"})
  417. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  418. with open(file_body_path, "w") as f:
  419. json.dump(form_data.model_dump(exclude_none=True), f)
  420. log.debug(f"images: {images}")
  421. return images
  422. elif (
  423. app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
  424. ):
  425. if form_data.model:
  426. set_image_model(form_data.model)
  427. data = {
  428. "prompt": form_data.prompt,
  429. "batch_size": form_data.n,
  430. "width": width,
  431. "height": height,
  432. }
  433. if app.state.config.IMAGE_STEPS is not None:
  434. data["steps"] = app.state.config.IMAGE_STEPS
  435. if form_data.negative_prompt is not None:
  436. data["negative_prompt"] = form_data.negative_prompt
  437. # Use asyncio.to_thread for the requests.post call
  438. r = await asyncio.to_thread(
  439. requests.post,
  440. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  441. json=data,
  442. headers={"authorization": get_automatic1111_api_auth()},
  443. )
  444. res = r.json()
  445. log.debug(f"res: {res}")
  446. images = []
  447. for image in res["images"]:
  448. image_filename = save_b64_image(image)
  449. images.append({"url": f"/cache/image/generations/{image_filename}"})
  450. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  451. with open(file_body_path, "w") as f:
  452. json.dump({**data, "info": res["info"]}, f)
  453. return images
  454. except Exception as e:
  455. error = e
  456. if r != None:
  457. data = r.json()
  458. if "error" in data:
  459. error = data["error"]["message"]
  460. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))