main.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. from fastapi import (
  2. FastAPI,
  3. Request,
  4. Depends,
  5. HTTPException,
  6. )
  7. from fastapi.middleware.cors import CORSMiddleware
  8. from typing import Optional
  9. from pydantic import BaseModel
  10. from pathlib import Path
  11. import mimetypes
  12. import uuid
  13. import base64
  14. import json
  15. import logging
  16. import re
  17. import requests
  18. from utils.utils import (
  19. get_verified_user,
  20. get_admin_user,
  21. )
  22. from apps.images.utils.comfyui import (
  23. ComfyUIWorkflow,
  24. ComfyUIGenerateImageForm,
  25. comfyui_generate_image,
  26. )
  27. from constants import ERROR_MESSAGES
  28. from config import (
  29. SRC_LOG_LEVELS,
  30. CACHE_DIR,
  31. IMAGE_GENERATION_ENGINE,
  32. ENABLE_IMAGE_GENERATION,
  33. AUTOMATIC1111_BASE_URL,
  34. AUTOMATIC1111_API_AUTH,
  35. COMFYUI_BASE_URL,
  36. COMFYUI_WORKFLOW,
  37. IMAGES_OPENAI_API_BASE_URL,
  38. IMAGES_OPENAI_API_KEY,
  39. IMAGE_GENERATION_MODEL,
  40. IMAGE_SIZE,
  41. IMAGE_STEPS,
  42. CORS_ALLOW_ORIGIN,
  43. AppConfig,
  44. )
  45. log = logging.getLogger(__name__)
  46. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  47. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  48. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  49. app = FastAPI()
  50. app.add_middleware(
  51. CORSMiddleware,
  52. allow_origins=CORS_ALLOW_ORIGIN,
  53. allow_credentials=True,
  54. allow_methods=["*"],
  55. allow_headers=["*"],
  56. )
  57. app.state.config = AppConfig()
  58. app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
  59. app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
  60. app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  61. app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
  62. app.state.config.MODEL = IMAGE_GENERATION_MODEL
  63. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  64. app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
  65. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  66. app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
  67. app.state.config.IMAGE_SIZE = IMAGE_SIZE
  68. app.state.config.IMAGE_STEPS = IMAGE_STEPS
  69. @app.get("/config")
  70. async def get_config(request: Request, user=Depends(get_admin_user)):
  71. return {
  72. "enabled": app.state.config.ENABLED,
  73. "engine": app.state.config.ENGINE,
  74. "openai": {
  75. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  76. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  77. },
  78. "automatic1111": {
  79. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  80. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  81. },
  82. "comfyui": {
  83. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  84. "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
  85. "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
  86. },
  87. }
class OpenAIConfigForm(BaseModel):
    """OpenAI-compatible image API settings submitted to ``POST /config/update``."""

    # Base URL; generation requests go to "<OPENAI_API_BASE_URL>/images/generations".
    OPENAI_API_BASE_URL: str
    # Sent as "Authorization: Bearer <key>" on generation requests.
    OPENAI_API_KEY: str
class Automatic1111ConfigForm(BaseModel):
    """AUTOMATIC1111 backend settings submitted to ``POST /config/update``."""

    # Base URL of the server; "/sdapi/v1/..." paths are appended to it.
    AUTOMATIC1111_BASE_URL: str
    # Credentials that get_automatic1111_api_auth base64-encodes into an
    # HTTP Basic header (presumably "user:password" - TODO confirm).
    AUTOMATIC1111_API_AUTH: str
class ComfyUIConfigForm(BaseModel):
    """ComfyUI backend settings submitted to ``POST /config/update``."""

    # Base URL; "/object_info" is requested from it to list checkpoints.
    COMFYUI_BASE_URL: str
    # Workflow definition, stored verbatim in app.state.config
    # (presumably serialized workflow JSON - TODO confirm).
    COMFYUI_WORKFLOW: str
    # Stored verbatim; the per-node dict schema is not visible in this module.
    COMFYUI_WORKFLOW_NODES: list[dict]
class ConfigForm(BaseModel):
    """Full body of ``POST /config/update``; mirrors the ``GET /config`` response."""

    # Stored as app.state.config.ENABLED.
    enabled: bool
    # Backend selector. This module handles "openai", "comfyui" and
    # "automatic1111"; an empty string is treated like "automatic1111".
    engine: str
    openai: OpenAIConfigForm
    automatic1111: Automatic1111ConfigForm
    comfyui: ComfyUIConfigForm
  104. @app.post("/config/update")
  105. async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
  106. app.state.config.ENGINE = form_data.engine
  107. app.state.config.ENABLED = form_data.enabled
  108. app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
  109. app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
  110. app.state.config.AUTOMATIC1111_BASE_URL = (
  111. form_data.automatic1111.AUTOMATIC1111_BASE_URL
  112. )
  113. app.state.config.AUTOMATIC1111_API_AUTH = (
  114. form_data.automatic1111.AUTOMATIC1111_API_AUTH
  115. )
  116. app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL
  117. app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
  118. app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
  119. return {
  120. "enabled": app.state.config.ENABLED,
  121. "engine": app.state.config.ENGINE,
  122. "openai": {
  123. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  124. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  125. },
  126. "automatic1111": {
  127. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  128. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  129. },
  130. "comfyui": {
  131. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  132. "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
  133. "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
  134. },
  135. }
  136. def get_automatic1111_api_auth():
  137. if app.state.config.AUTOMATIC1111_API_AUTH is None:
  138. return ""
  139. else:
  140. auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
  141. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  142. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  143. return f"Basic {auth1111_base64_encoded_string}"
  144. def set_image_model(model: str):
  145. app.state.config.MODEL = model
  146. if app.state.config.ENGINE in ["", "automatic1111"]:
  147. api_auth = get_automatic1111_api_auth()
  148. r = requests.get(
  149. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  150. headers={"authorization": api_auth},
  151. )
  152. options = r.json()
  153. if model != options["sd_model_checkpoint"]:
  154. options["sd_model_checkpoint"] = model
  155. r = requests.post(
  156. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  157. json=options,
  158. headers={"authorization": api_auth},
  159. )
  160. return app.state.config.MODEL
  161. def get_image_model():
  162. if app.state.config.ENGINE == "openai":
  163. return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
  164. elif app.state.config.ENGINE == "comfyui":
  165. return app.state.config.MODEL if app.state.config.MODEL else ""
  166. elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
  167. try:
  168. r = requests.get(
  169. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  170. headers={"authorization": get_automatic1111_api_auth()},
  171. )
  172. options = r.json()
  173. return options["sd_model_checkpoint"]
  174. except Exception as e:
  175. app.state.config.ENABLED = False
  176. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class ImageConfigForm(BaseModel):
    """Body of ``POST /image/config/update``."""

    # Stored as app.state.config.MODEL.
    model: str
    # Default size as "<width>x<height>"; format-checked by update_image_config.
    size: str
    # Step count sent as "steps" to the comfyui/automatic1111 engines;
    # update_image_config rejects negative values.
    steps: int
  181. @app.get("/image/config")
  182. async def get_image_config(user=Depends(get_admin_user)):
  183. return {
  184. "MODEL": app.state.config.MODEL,
  185. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  186. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  187. }
  188. @app.post("/image/config/update")
  189. async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
  190. app.state.config.MODEL = form_data.model
  191. pattern = r"^\d+x\d+$"
  192. if re.match(pattern, form_data.size):
  193. app.state.config.IMAGE_SIZE = form_data.size
  194. else:
  195. raise HTTPException(
  196. status_code=400,
  197. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  198. )
  199. if form_data.steps >= 0:
  200. app.state.config.IMAGE_STEPS = form_data.steps
  201. else:
  202. raise HTTPException(
  203. status_code=400,
  204. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  205. )
  206. return {
  207. "MODEL": app.state.config.MODEL,
  208. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  209. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  210. }
  211. @app.get("/models")
  212. def get_models(user=Depends(get_verified_user)):
  213. try:
  214. if app.state.config.ENGINE == "openai":
  215. return [
  216. {"id": "dall-e-2", "name": "DALL·E 2"},
  217. {"id": "dall-e-3", "name": "DALL·E 3"},
  218. ]
  219. elif app.state.config.ENGINE == "comfyui":
  220. # TODO - get models from comfyui
  221. r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
  222. info = r.json()
  223. return list(
  224. map(
  225. lambda model: {"id": model, "name": model},
  226. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  227. )
  228. )
  229. elif (
  230. app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
  231. ):
  232. r = requests.get(
  233. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  234. headers={"authorization": get_automatic1111_api_auth()},
  235. )
  236. models = r.json()
  237. return list(
  238. map(
  239. lambda model: {"id": model["title"], "name": model["model_name"]},
  240. models,
  241. )
  242. )
  243. except Exception as e:
  244. app.state.config.ENABLED = False
  245. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class GenerateImageForm(BaseModel):
    """Body of ``POST /generations``."""

    # Checkpoint to switch to before generating; only used by the
    # automatic1111 engine (via set_image_model).
    model: Optional[str] = None
    # Text prompt, forwarded to every engine.
    prompt: str
    # "<width>x<height>" override; only honored by the openai engine. The
    # other engines use app.state.config.IMAGE_SIZE.
    size: Optional[str] = None
    # Number of images (sent as "batch_size" to automatic1111).
    n: int = 1
    # Forwarded to the comfyui and automatic1111 engines; unused for openai.
    negative_prompt: Optional[str] = None
  252. def save_b64_image(b64_str):
  253. try:
  254. image_id = str(uuid.uuid4())
  255. if "," in b64_str:
  256. header, encoded = b64_str.split(",", 1)
  257. mime_type = header.split(";")[0]
  258. img_data = base64.b64decode(encoded)
  259. image_format = mimetypes.guess_extension(mime_type)
  260. image_filename = f"{image_id}{image_format}"
  261. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  262. with open(file_path, "wb") as f:
  263. f.write(img_data)
  264. return image_filename
  265. else:
  266. image_filename = f"{image_id}.png"
  267. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  268. img_data = base64.b64decode(b64_str)
  269. # Write the image data to a file
  270. with open(file_path, "wb") as f:
  271. f.write(img_data)
  272. return image_filename
  273. except Exception as e:
  274. log.exception(f"Error saving image: {e}")
  275. return None
  276. def save_url_image(url):
  277. image_id = str(uuid.uuid4())
  278. try:
  279. r = requests.get(url)
  280. r.raise_for_status()
  281. if r.headers["content-type"].split("/")[0] == "image":
  282. mime_type = r.headers["content-type"]
  283. image_format = mimetypes.guess_extension(mime_type)
  284. if not image_format:
  285. raise ValueError("Could not determine image type from MIME type")
  286. image_filename = f"{image_id}{image_format}"
  287. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  288. with open(file_path, "wb") as image_file:
  289. for chunk in r.iter_content(chunk_size=8192):
  290. image_file.write(chunk)
  291. return image_filename
  292. else:
  293. log.error(f"Url does not point to an image.")
  294. return None
  295. except Exception as e:
  296. log.exception(f"Error saving image: {e}")
  297. return None
  298. @app.post("/generations")
  299. async def image_generations(
  300. form_data: GenerateImageForm,
  301. user=Depends(get_verified_user),
  302. ):
  303. width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
  304. r = None
  305. try:
  306. if app.state.config.ENGINE == "openai":
  307. headers = {}
  308. headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
  309. headers["Content-Type"] = "application/json"
  310. data = {
  311. "model": (
  312. app.state.config.MODEL
  313. if app.state.config.MODEL != ""
  314. else "dall-e-2"
  315. ),
  316. "prompt": form_data.prompt,
  317. "n": form_data.n,
  318. "size": (
  319. form_data.size if form_data.size else app.state.config.IMAGE_SIZE
  320. ),
  321. "response_format": "b64_json",
  322. }
  323. r = requests.post(
  324. url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
  325. json=data,
  326. headers=headers,
  327. )
  328. r.raise_for_status()
  329. res = r.json()
  330. images = []
  331. for image in res["data"]:
  332. image_filename = save_b64_image(image["b64_json"])
  333. images.append({"url": f"/cache/image/generations/{image_filename}"})
  334. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  335. with open(file_body_path, "w") as f:
  336. json.dump(data, f)
  337. return images
  338. elif app.state.config.ENGINE == "comfyui":
  339. data = {
  340. "prompt": form_data.prompt,
  341. "width": width,
  342. "height": height,
  343. "n": form_data.n,
  344. }
  345. if app.state.config.IMAGE_STEPS is not None:
  346. data["steps"] = app.state.config.IMAGE_STEPS
  347. if form_data.negative_prompt is not None:
  348. data["negative_prompt"] = form_data.negative_prompt
  349. form_data = ComfyUIGenerateImageForm(**data)
  350. res = await comfyui_generate_image(
  351. app.state.config.MODEL,
  352. form_data,
  353. user.id,
  354. app.state.config.COMFYUI_BASE_URL,
  355. )
  356. log.debug(f"res: {res}")
  357. images = []
  358. for image in res["data"]:
  359. image_filename = save_url_image(image["url"])
  360. images.append({"url": f"/cache/image/generations/{image_filename}"})
  361. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  362. with open(file_body_path, "w") as f:
  363. json.dump(data.model_dump(exclude_none=True), f)
  364. log.debug(f"images: {images}")
  365. return images
  366. elif (
  367. app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
  368. ):
  369. if form_data.model:
  370. set_image_model(form_data.model)
  371. data = {
  372. "prompt": form_data.prompt,
  373. "batch_size": form_data.n,
  374. "width": width,
  375. "height": height,
  376. }
  377. if app.state.config.IMAGE_STEPS is not None:
  378. data["steps"] = app.state.config.IMAGE_STEPS
  379. if form_data.negative_prompt is not None:
  380. data["negative_prompt"] = form_data.negative_prompt
  381. r = requests.post(
  382. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  383. json=data,
  384. headers={"authorization": get_automatic1111_api_auth()},
  385. )
  386. res = r.json()
  387. log.debug(f"res: {res}")
  388. images = []
  389. for image in res["images"]:
  390. image_filename = save_b64_image(image)
  391. images.append({"url": f"/cache/image/generations/{image_filename}"})
  392. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  393. with open(file_body_path, "w") as f:
  394. json.dump({**data, "info": res["info"]}, f)
  395. return images
  396. except Exception as e:
  397. error = e
  398. if r != None:
  399. data = r.json()
  400. if "error" in data:
  401. error = data["error"]["message"]
  402. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))