main.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. )
  9. from fastapi.middleware.cors import CORSMiddleware
  10. from constants import ERROR_MESSAGES
  11. from utils.utils import (
  12. get_verified_user,
  13. get_admin_user,
  14. )
  15. from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
  16. from typing import Optional
  17. from pydantic import BaseModel
  18. from pathlib import Path
  19. import mimetypes
  20. import uuid
  21. import base64
  22. import json
  23. import logging
  24. from config import (
  25. SRC_LOG_LEVELS,
  26. CACHE_DIR,
  27. IMAGE_GENERATION_ENGINE,
  28. ENABLE_IMAGE_GENERATION,
  29. AUTOMATIC1111_BASE_URL,
  30. AUTOMATIC1111_API_AUTH,
  31. COMFYUI_BASE_URL,
  32. COMFYUI_CFG_SCALE,
  33. COMFYUI_SAMPLER,
  34. COMFYUI_SCHEDULER,
  35. COMFYUI_SD3,
  36. COMFYUI_FLUX,
  37. COMFYUI_FLUX_WEIGHT_DTYPE,
  38. COMFYUI_FLUX_FP8_CLIP,
  39. IMAGES_OPENAI_API_BASE_URL,
  40. IMAGES_OPENAI_API_KEY,
  41. IMAGE_GENERATION_MODEL,
  42. IMAGE_SIZE,
  43. IMAGE_STEPS,
  44. AppConfig,
  45. CORS_ALLOW_ORIGIN,
  46. )
  47. log = logging.getLogger(__name__)
  48. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  49. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  50. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  51. app = FastAPI()
  52. app.add_middleware(
  53. CORSMiddleware,
  54. allow_origins=CORS_ALLOW_ORIGIN,
  55. allow_credentials=True,
  56. allow_methods=["*"],
  57. allow_headers=["*"],
  58. )
  59. app.state.config = AppConfig()
  60. app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
  61. app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
  62. app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  63. app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
  64. app.state.config.MODEL = IMAGE_GENERATION_MODEL
  65. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  66. app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
  67. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  68. app.state.config.IMAGE_SIZE = IMAGE_SIZE
  69. app.state.config.IMAGE_STEPS = IMAGE_STEPS
  70. app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
  71. app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
  72. app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
  73. app.state.config.COMFYUI_SD3 = COMFYUI_SD3
  74. app.state.config.COMFYUI_FLUX = COMFYUI_FLUX
  75. app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE
  76. app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
  77. def get_automatic1111_api_auth():
  78. if app.state.config.AUTOMATIC1111_API_AUTH is None:
  79. return ""
  80. else:
  81. auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
  82. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  83. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  84. return f"Basic {auth1111_base64_encoded_string}"
  85. @app.get("/config")
  86. async def get_config(request: Request, user=Depends(get_admin_user)):
  87. return {
  88. "engine": app.state.config.ENGINE,
  89. "enabled": app.state.config.ENABLED,
  90. }
  91. class ConfigUpdateForm(BaseModel):
  92. engine: str
  93. enabled: bool
  94. @app.post("/config/update")
  95. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  96. app.state.config.ENGINE = form_data.engine
  97. app.state.config.ENABLED = form_data.enabled
  98. return {
  99. "engine": app.state.config.ENGINE,
  100. "enabled": app.state.config.ENABLED,
  101. }
  102. class EngineUrlUpdateForm(BaseModel):
  103. AUTOMATIC1111_BASE_URL: Optional[str] = None
  104. AUTOMATIC1111_API_AUTH: Optional[str] = None
  105. COMFYUI_BASE_URL: Optional[str] = None
  106. @app.get("/url")
  107. async def get_engine_url(user=Depends(get_admin_user)):
  108. return {
  109. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  110. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  111. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  112. }
  113. @app.post("/url/update")
  114. async def update_engine_url(
  115. form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  116. ):
  117. if form_data.AUTOMATIC1111_BASE_URL is None:
  118. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  119. else:
  120. url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
  121. try:
  122. r = requests.head(url)
  123. r.raise_for_status()
  124. app.state.config.AUTOMATIC1111_BASE_URL = url
  125. except Exception as e:
  126. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  127. if form_data.COMFYUI_BASE_URL is None:
  128. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  129. else:
  130. url = form_data.COMFYUI_BASE_URL.strip("/")
  131. try:
  132. r = requests.head(url)
  133. r.raise_for_status()
  134. app.state.config.COMFYUI_BASE_URL = url
  135. except Exception as e:
  136. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  137. if form_data.AUTOMATIC1111_API_AUTH is None:
  138. app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
  139. else:
  140. app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
  141. return {
  142. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  143. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  144. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  145. "status": True,
  146. }
  147. class OpenAIConfigUpdateForm(BaseModel):
  148. url: str
  149. key: str
  150. @app.get("/openai/config")
  151. async def get_openai_config(user=Depends(get_admin_user)):
  152. return {
  153. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  154. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  155. }
  156. @app.post("/openai/config/update")
  157. async def update_openai_config(
  158. form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
  159. ):
  160. if form_data.key == "":
  161. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  162. app.state.config.OPENAI_API_BASE_URL = form_data.url
  163. app.state.config.OPENAI_API_KEY = form_data.key
  164. return {
  165. "status": True,
  166. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  167. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  168. }
  169. class ImageSizeUpdateForm(BaseModel):
  170. size: str
  171. @app.get("/size")
  172. async def get_image_size(user=Depends(get_admin_user)):
  173. return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
  174. @app.post("/size/update")
  175. async def update_image_size(
  176. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  177. ):
  178. pattern = r"^\d+x\d+$" # Regular expression pattern
  179. if re.match(pattern, form_data.size):
  180. app.state.config.IMAGE_SIZE = form_data.size
  181. return {
  182. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  183. "status": True,
  184. }
  185. else:
  186. raise HTTPException(
  187. status_code=400,
  188. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  189. )
  190. class ImageStepsUpdateForm(BaseModel):
  191. steps: int
  192. @app.get("/steps")
  193. async def get_image_size(user=Depends(get_admin_user)):
  194. return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
  195. @app.post("/steps/update")
  196. async def update_image_size(
  197. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  198. ):
  199. if form_data.steps >= 0:
  200. app.state.config.IMAGE_STEPS = form_data.steps
  201. return {
  202. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  203. "status": True,
  204. }
  205. else:
  206. raise HTTPException(
  207. status_code=400,
  208. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  209. )
  210. @app.get("/models")
  211. def get_models(user=Depends(get_verified_user)):
  212. try:
  213. if app.state.config.ENGINE == "openai":
  214. return [
  215. {"id": "dall-e-2", "name": "DALL·E 2"},
  216. {"id": "dall-e-3", "name": "DALL·E 3"},
  217. ]
  218. elif app.state.config.ENGINE == "comfyui":
  219. r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
  220. info = r.json()
  221. return list(
  222. map(
  223. lambda model: {"id": model, "name": model},
  224. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  225. )
  226. )
  227. else:
  228. r = requests.get(
  229. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  230. headers={"authorization": get_automatic1111_api_auth()},
  231. )
  232. models = r.json()
  233. return list(
  234. map(
  235. lambda model: {"id": model["title"], "name": model["model_name"]},
  236. models,
  237. )
  238. )
  239. except Exception as e:
  240. app.state.config.ENABLED = False
  241. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  242. @app.get("/models/default")
  243. async def get_default_model(user=Depends(get_admin_user)):
  244. try:
  245. if app.state.config.ENGINE == "openai":
  246. return {
  247. "model": (
  248. app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
  249. )
  250. }
  251. elif app.state.config.ENGINE == "comfyui":
  252. return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
  253. else:
  254. r = requests.get(
  255. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  256. headers={"authorization": get_automatic1111_api_auth()},
  257. )
  258. options = r.json()
  259. return {"model": options["sd_model_checkpoint"]}
  260. except Exception as e:
  261. app.state.config.ENABLED = False
  262. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  263. class UpdateModelForm(BaseModel):
  264. model: str
  265. def set_model_handler(model: str):
  266. if app.state.config.ENGINE in ["openai", "comfyui"]:
  267. app.state.config.MODEL = model
  268. return app.state.config.MODEL
  269. else:
  270. api_auth = get_automatic1111_api_auth()
  271. r = requests.get(
  272. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  273. headers={"authorization": api_auth},
  274. )
  275. options = r.json()
  276. if model != options["sd_model_checkpoint"]:
  277. options["sd_model_checkpoint"] = model
  278. r = requests.post(
  279. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  280. json=options,
  281. headers={"authorization": api_auth},
  282. )
  283. return options
  284. @app.post("/models/default/update")
  285. def update_default_model(
  286. form_data: UpdateModelForm,
  287. user=Depends(get_verified_user),
  288. ):
  289. return set_model_handler(form_data.model)
  290. class GenerateImageForm(BaseModel):
  291. model: Optional[str] = None
  292. prompt: str
  293. n: int = 1
  294. size: Optional[str] = None
  295. negative_prompt: Optional[str] = None
  296. def save_b64_image(b64_str):
  297. try:
  298. image_id = str(uuid.uuid4())
  299. if "," in b64_str:
  300. header, encoded = b64_str.split(",", 1)
  301. mime_type = header.split(";")[0]
  302. img_data = base64.b64decode(encoded)
  303. image_format = mimetypes.guess_extension(mime_type)
  304. image_filename = f"{image_id}{image_format}"
  305. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  306. with open(file_path, "wb") as f:
  307. f.write(img_data)
  308. return image_filename
  309. else:
  310. image_filename = f"{image_id}.png"
  311. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  312. img_data = base64.b64decode(b64_str)
  313. # Write the image data to a file
  314. with open(file_path, "wb") as f:
  315. f.write(img_data)
  316. return image_filename
  317. except Exception as e:
  318. log.exception(f"Error saving image: {e}")
  319. return None
  320. def save_url_image(url):
  321. image_id = str(uuid.uuid4())
  322. try:
  323. r = requests.get(url)
  324. r.raise_for_status()
  325. if r.headers["content-type"].split("/")[0] == "image":
  326. mime_type = r.headers["content-type"]
  327. image_format = mimetypes.guess_extension(mime_type)
  328. if not image_format:
  329. raise ValueError("Could not determine image type from MIME type")
  330. image_filename = f"{image_id}{image_format}"
  331. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  332. with open(file_path, "wb") as image_file:
  333. for chunk in r.iter_content(chunk_size=8192):
  334. image_file.write(chunk)
  335. return image_filename
  336. else:
  337. log.error(f"Url does not point to an image.")
  338. return None
  339. except Exception as e:
  340. log.exception(f"Error saving image: {e}")
  341. return None
  342. @app.post("/generations")
  343. async def image_generations(
  344. form_data: GenerateImageForm,
  345. user=Depends(get_verified_user),
  346. ):
  347. width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
  348. r = None
  349. try:
  350. if app.state.config.ENGINE == "openai":
  351. headers = {}
  352. headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
  353. headers["Content-Type"] = "application/json"
  354. data = {
  355. "model": (
  356. app.state.config.MODEL
  357. if app.state.config.MODEL != ""
  358. else "dall-e-2"
  359. ),
  360. "prompt": form_data.prompt,
  361. "n": form_data.n,
  362. "size": (
  363. form_data.size if form_data.size else app.state.config.IMAGE_SIZE
  364. ),
  365. "response_format": "b64_json",
  366. }
  367. r = requests.post(
  368. url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
  369. json=data,
  370. headers=headers,
  371. )
  372. r.raise_for_status()
  373. res = r.json()
  374. images = []
  375. for image in res["data"]:
  376. image_filename = save_b64_image(image["b64_json"])
  377. images.append({"url": f"/cache/image/generations/{image_filename}"})
  378. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  379. with open(file_body_path, "w") as f:
  380. json.dump(data, f)
  381. return images
  382. elif app.state.config.ENGINE == "comfyui":
  383. data = {
  384. "prompt": form_data.prompt,
  385. "width": width,
  386. "height": height,
  387. "n": form_data.n,
  388. }
  389. if app.state.config.IMAGE_STEPS is not None:
  390. data["steps"] = app.state.config.IMAGE_STEPS
  391. if form_data.negative_prompt is not None:
  392. data["negative_prompt"] = form_data.negative_prompt
  393. if app.state.config.COMFYUI_CFG_SCALE:
  394. data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
  395. if app.state.config.COMFYUI_SAMPLER is not None:
  396. data["sampler"] = app.state.config.COMFYUI_SAMPLER
  397. if app.state.config.COMFYUI_SCHEDULER is not None:
  398. data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
  399. if app.state.config.COMFYUI_SD3 is not None:
  400. data["sd3"] = app.state.config.COMFYUI_SD3
  401. if app.state.config.COMFYUI_FLUX is not None:
  402. data["flux"] = app.state.config.COMFYUI_FLUX
  403. if app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE is not None:
  404. data["flux_weight_dtype"] = app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE
  405. if app.state.config.COMFYUI_FLUX_FP8_CLIP is not None:
  406. data["flux_fp8_clip"] = app.state.config.COMFYUI_FLUX_FP8_CLIP
  407. data = ImageGenerationPayload(**data)
  408. res = await comfyui_generate_image(
  409. app.state.config.MODEL,
  410. data,
  411. user.id,
  412. app.state.config.COMFYUI_BASE_URL,
  413. )
  414. log.debug(f"res: {res}")
  415. images = []
  416. for image in res["data"]:
  417. image_filename = save_url_image(image["url"])
  418. images.append({"url": f"/cache/image/generations/{image_filename}"})
  419. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  420. with open(file_body_path, "w") as f:
  421. json.dump(data.model_dump(exclude_none=True), f)
  422. log.debug(f"images: {images}")
  423. return images
  424. else:
  425. if form_data.model:
  426. set_model_handler(form_data.model)
  427. data = {
  428. "prompt": form_data.prompt,
  429. "batch_size": form_data.n,
  430. "width": width,
  431. "height": height,
  432. }
  433. if app.state.config.IMAGE_STEPS is not None:
  434. data["steps"] = app.state.config.IMAGE_STEPS
  435. if form_data.negative_prompt is not None:
  436. data["negative_prompt"] = form_data.negative_prompt
  437. r = requests.post(
  438. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  439. json=data,
  440. headers={"authorization": get_automatic1111_api_auth()},
  441. )
  442. res = r.json()
  443. log.debug(f"res: {res}")
  444. images = []
  445. for image in res["images"]:
  446. image_filename = save_b64_image(image)
  447. images.append({"url": f"/cache/image/generations/{image_filename}"})
  448. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  449. with open(file_body_path, "w") as f:
  450. json.dump({**data, "info": res["info"]}, f)
  451. return images
  452. except Exception as e:
  453. error = e
  454. if r != None:
  455. data = r.json()
  456. if "error" in data:
  457. error = data["error"]["message"]
  458. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))