main.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. from fastapi import (
  2. FastAPI,
  3. Request,
  4. Depends,
  5. HTTPException,
  6. )
  7. from fastapi.middleware.cors import CORSMiddleware
  8. from typing import Optional
  9. from pydantic import BaseModel
  10. from pathlib import Path
  11. import mimetypes
  12. import uuid
  13. import base64
  14. import json
  15. import logging
  16. import re
  17. import requests
  18. from utils.utils import (
  19. get_verified_user,
  20. get_admin_user,
  21. )
  22. from apps.images.utils.comfyui import (
  23. ComfyUIWorkflow,
  24. ComfyUIGenerateImageForm,
  25. comfyui_generate_image,
  26. )
  27. from constants import ERROR_MESSAGES
  28. from config import (
  29. SRC_LOG_LEVELS,
  30. CACHE_DIR,
  31. IMAGE_GENERATION_ENGINE,
  32. ENABLE_IMAGE_GENERATION,
  33. AUTOMATIC1111_BASE_URL,
  34. AUTOMATIC1111_API_AUTH,
  35. COMFYUI_BASE_URL,
  36. COMFYUI_WORKFLOW,
  37. COMFYUI_WORKFLOW_NODES,
  38. IMAGES_OPENAI_API_BASE_URL,
  39. IMAGES_OPENAI_API_KEY,
  40. IMAGE_GENERATION_MODEL,
  41. IMAGE_SIZE,
  42. IMAGE_STEPS,
  43. CORS_ALLOW_ORIGIN,
  44. AppConfig,
  45. )
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["IMAGES"])

  # Generated images and their JSON sidecar files are written under this directory.
  IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

  app = FastAPI()
  app.add_middleware(
      CORSMiddleware,
      allow_origins=CORS_ALLOW_ORIGIN,
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

  # Runtime-editable settings, seeded (in this order) from the startup values
  # imported from `config`; the endpoints below read and update these attributes.
  app.state.config = AppConfig()
  for _setting, _initial in {
      "ENGINE": IMAGE_GENERATION_ENGINE,
      "ENABLED": ENABLE_IMAGE_GENERATION,
      "OPENAI_API_BASE_URL": IMAGES_OPENAI_API_BASE_URL,
      "OPENAI_API_KEY": IMAGES_OPENAI_API_KEY,
      "MODEL": IMAGE_GENERATION_MODEL,
      "AUTOMATIC1111_BASE_URL": AUTOMATIC1111_BASE_URL,
      "AUTOMATIC1111_API_AUTH": AUTOMATIC1111_API_AUTH,
      "COMFYUI_BASE_URL": COMFYUI_BASE_URL,
      "COMFYUI_WORKFLOW": COMFYUI_WORKFLOW,
      "COMFYUI_WORKFLOW_NODES": COMFYUI_WORKFLOW_NODES,
      "IMAGE_SIZE": IMAGE_SIZE,
      "IMAGE_STEPS": IMAGE_STEPS,
  }.items():
      setattr(app.state.config, _setting, _initial)
  # Don't leak the loop variables as module attributes.
  del _setting, _initial
  @app.get("/config")
  async def get_config(request: Request, user=Depends(get_admin_user)):
      """Return the image-generation settings (admin only).

      The response carries the global ``enabled`` flag and selected ``engine``, plus
      one section per backend ("openai", "automatic1111", "comfyui"). Note that the
      sections include the stored OpenAI API key and AUTOMATIC1111 credentials.
      """
      cfg = app.state.config
      return {
          "enabled": cfg.ENABLED,
          "engine": cfg.ENGINE,
          "openai": {
              key: getattr(cfg, key)
              for key in ("OPENAI_API_BASE_URL", "OPENAI_API_KEY")
          },
          "automatic1111": {
              key: getattr(cfg, key)
              for key in ("AUTOMATIC1111_BASE_URL", "AUTOMATIC1111_API_AUTH")
          },
          "comfyui": {
              key: getattr(cfg, key)
              for key in (
                  "COMFYUI_BASE_URL",
                  "COMFYUI_WORKFLOW",
                  "COMFYUI_WORKFLOW_NODES",
              )
          },
      }
  # OpenAI section of the POST /config/update payload (ConfigForm.openai).
  # update_config copies both fields verbatim onto app.state.config.
  class OpenAIConfigForm(BaseModel):
      OPENAI_API_BASE_URL: str  # "/images/generations" is appended when generating
      OPENAI_API_KEY: str  # sent as "Authorization: Bearer <key>"
  # AUTOMATIC1111 section of the POST /config/update payload
  # (ConfigForm.automatic1111); update_config copies both fields onto app.state.config.
  class Automatic1111ConfigForm(BaseModel):
      AUTOMATIC1111_BASE_URL: str  # "/sdapi/v1/..." endpoint paths are appended to this
      # Base64-encoded into an HTTP Basic "authorization" header by
      # get_automatic1111_api_auth(); presumably "user:password" — TODO confirm.
      AUTOMATIC1111_API_AUTH: str
  # ComfyUI section of the POST /config/update payload (ConfigForm.comfyui).
  class ComfyUIConfigForm(BaseModel):
      COMFYUI_BASE_URL: str
      # Workflow as a JSON string; get_models() parses it with json.loads and looks
      # nodes up by id to read their "class_type".
      COMFYUI_WORKFLOW: str
      # Node descriptors; get_models() reads each entry's "type" and "node_ids" keys.
      COMFYUI_WORKFLOW_NODES: list[dict]
  # Request body of POST /config/update.
  class ConfigForm(BaseModel):
      enabled: bool
      # Not validated here. Engines handled in this module: "openai", "comfyui",
      # "automatic1111", and "" (treated like "automatic1111").
      engine: str
      openai: OpenAIConfigForm
      automatic1111: Automatic1111ConfigForm
      comfyui: ComfyUIConfigForm
  @app.post("/config/update")
  async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
      """Replace the image-generation settings (admin only) and echo them back.

      Every field of ``form_data`` is written onto ``app.state.config``; the
      response has the same shape as GET /config.
      """
      cfg = app.state.config
      openai_form = form_data.openai
      a1111_form = form_data.automatic1111
      comfy_form = form_data.comfyui

      # Applied in a fixed order, one attribute at a time.
      assignments = (
          ("ENGINE", form_data.engine),
          ("ENABLED", form_data.enabled),
          ("OPENAI_API_BASE_URL", openai_form.OPENAI_API_BASE_URL),
          ("OPENAI_API_KEY", openai_form.OPENAI_API_KEY),
          ("AUTOMATIC1111_BASE_URL", a1111_form.AUTOMATIC1111_BASE_URL),
          ("AUTOMATIC1111_API_AUTH", a1111_form.AUTOMATIC1111_API_AUTH),
          ("COMFYUI_BASE_URL", comfy_form.COMFYUI_BASE_URL),
          ("COMFYUI_WORKFLOW", comfy_form.COMFYUI_WORKFLOW),
          ("COMFYUI_WORKFLOW_NODES", comfy_form.COMFYUI_WORKFLOW_NODES),
      )
      for attr_name, new_value in assignments:
          setattr(cfg, attr_name, new_value)

      return {
          "enabled": cfg.ENABLED,
          "engine": cfg.ENGINE,
          "openai": {
              key: getattr(cfg, key)
              for key in ("OPENAI_API_BASE_URL", "OPENAI_API_KEY")
          },
          "automatic1111": {
              key: getattr(cfg, key)
              for key in ("AUTOMATIC1111_BASE_URL", "AUTOMATIC1111_API_AUTH")
          },
          "comfyui": {
              key: getattr(cfg, key)
              for key in (
                  "COMFYUI_BASE_URL",
                  "COMFYUI_WORKFLOW",
                  "COMFYUI_WORKFLOW_NODES",
              )
          },
      }
  def get_automatic1111_api_auth():
      """Build the ``authorization`` header value for AUTOMATIC1111 API requests.

      The configured ``AUTOMATIC1111_API_AUTH`` credentials (presumably
      ``user:password`` — TODO confirm) are UTF-8 encoded and base64-encoded as
      HTTP Basic auth.

      Returns:
          ``"Basic <base64>"`` when credentials are configured, otherwise ``""``.
      """
      credentials = app.state.config.AUTOMATIC1111_API_AUTH
      # Treat "" like None: the config form field is a plain str, and encoding an
      # empty string would send the meaningless header "Basic ".
      if not credentials:
          return ""
      token = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
      return f"Basic {token}"
  @app.get("/config/url/verify")
  async def verify_url(user=Depends(get_admin_user)):
      """Check that the active engine's base URL answers (admin only).

      AUTOMATIC1111 is probed at ``/sdapi/v1/options`` (with its auth header) and
      ComfyUI at ``/object_info``; any other engine is accepted without a probe.

      Returns:
          ``True`` when the probe succeeds or no probe applies.

      Raises:
          HTTPException(400): if the probe fails; image generation is also
          disabled (``ENABLED = False``) in that case.
      """
      engine = app.state.config.ENGINE
      if engine not in ("automatic1111", "comfyui"):
          return True

      try:
          if engine == "automatic1111":
              response = requests.get(
                  url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                  headers={"authorization": get_automatic1111_api_auth()},
              )
          else:
              response = requests.get(
                  url=f"{app.state.config.COMFYUI_BASE_URL}/object_info"
              )
          response.raise_for_status()
      except Exception:
          app.state.config.ENABLED = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
      return True
  def set_image_model(model: str):
      """Store ``model`` as the selected image model and activate it on AUTOMATIC1111.

      For the AUTOMATIC1111 engine (also used when ENGINE is ""), the current
      options are fetched and ``sd_model_checkpoint`` is posted back only when it
      differs from ``model``.

      Returns:
          The stored ``app.state.config.MODEL``.

      Raises:
          requests.HTTPError: if AUTOMATIC1111 rejects reading or updating options.
      """
      app.state.config.MODEL = model
      if app.state.config.ENGINE in ["", "automatic1111"]:
          options_url = f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
          headers = {"authorization": get_automatic1111_api_auth()}

          r = requests.get(url=options_url, headers=headers)
          r.raise_for_status()
          options = r.json()

          if model != options["sd_model_checkpoint"]:
              options["sd_model_checkpoint"] = model
              r = requests.post(url=options_url, json=options, headers=headers)
              # Surface a failed switch instead of silently generating with the
              # previously loaded checkpoint.
              r.raise_for_status()
      return app.state.config.MODEL
  def get_image_model():
      """Return the model currently selected for the active engine.

      - "automatic1111" or "": asks the AUTOMATIC1111 server for its
        ``sd_model_checkpoint``.
      - "openai": the configured MODEL, defaulting to "dall-e-2".
      - "comfyui": the configured MODEL, defaulting to "".
      - any other engine: ``None``.

      Raises:
          HTTPException(400): if the AUTOMATIC1111 query fails; image generation
          is also disabled (``ENABLED = False``) in that case.
      """
      engine = app.state.config.ENGINE

      if engine in ("automatic1111", ""):
          try:
              response = requests.get(
                  url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                  headers={"authorization": get_automatic1111_api_auth()},
              )
              return response.json()["sd_model_checkpoint"]
          except Exception as e:
              app.state.config.ENABLED = False
              raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

      if engine == "openai":
          return app.state.config.MODEL or "dall-e-2"
      if engine == "comfyui":
          return app.state.config.MODEL or ""
      return None
  # Request body of POST /image/config/update; validated by update_image_config.
  class ImageConfigForm(BaseModel):
      MODEL: str
      IMAGE_SIZE: str  # "<width>x<height>", e.g. "512x512"
      IMAGE_STEPS: int  # rejected when negative
  @app.get("/image/config")
  async def get_image_config(user=Depends(get_admin_user)):
      """Return the stored image model, size and step count (admin only)."""
      cfg = app.state.config
      return {key: getattr(cfg, key) for key in ("MODEL", "IMAGE_SIZE", "IMAGE_STEPS")}
  @app.post("/image/config/update")
  async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
      """Validate and store the image model, size and step count (admin only).

      All fields are validated before anything is written, so a rejected request
      leaves the stored configuration unchanged.

      Returns:
          The stored ``MODEL``, ``IMAGE_SIZE`` and ``IMAGE_STEPS``.

      Raises:
          HTTPException(400): if IMAGE_SIZE is not ``<width>x<height>`` or
          IMAGE_STEPS is negative.
      """
      # fullmatch rather than match with "^...$": "$" also matches just before a
      # trailing newline, which would accept "512x512\n".
      if not re.fullmatch(r"\d+x\d+", form_data.IMAGE_SIZE):
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
          )
      if form_data.IMAGE_STEPS < 0:
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
          )

      # Persist only after every field is known to be valid; MODEL used to be
      # written first, leaving a partial update behind on a 400.
      app.state.config.MODEL = form_data.MODEL
      app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
      app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS

      return {
          "MODEL": app.state.config.MODEL,
          "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
          "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
      }
  def _comfyui_model_choices(info, workflow, nodes):
      """Build the ComfyUI model list from ``/object_info`` and the workflow.

      Args:
          info: Parsed ComfyUI ``/object_info`` response, keyed by node class name.
          workflow: Parsed COMFYUI_WORKFLOW, keyed by node id.
          nodes: COMFYUI_WORKFLOW_NODES descriptors (dicts with "type"/"node_ids").

      Returns:
          ``[{"id": name, "name": name}, ...]`` taken from the first required input
          whose name contains "_name" on the class of the first configured "model"
          node. When there is no such node or input, the ``ckpt_name`` choices of
          ``CheckpointLoaderSimple`` are used, so a list (never ``None``) is returned.
      """
      model_node_id = next(
          (
              node["node_ids"][0]
              for node in nodes
              if node["type"] == "model" and node["node_ids"]
          ),
          None,
      )
      if model_node_id:
          required = info[workflow[model_node_id]["class_type"]]["input"]["required"]
          model_list_key = next((key for key in required if "_name" in key), None)
          if model_list_key:
              return [{"id": name, "name": name} for name in required[model_list_key][0]]

      # Fallback: the stock checkpoint loader's choices.
      checkpoints = info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0]
      return [{"id": name, "name": name} for name in checkpoints]


  @app.get("/models")
  def get_models(user=Depends(get_verified_user)):
      """List the models available for the configured image-generation engine.

      Returns:
          A list of ``{"id": ..., "name": ...}`` dicts: a fixed DALL·E list for
          "openai", the ComfyUI workflow's model choices for "comfyui", and the
          server's ``sd-models`` for "automatic1111"/"". Any other engine yields
          ``None``.

      Raises:
          HTTPException(400): if querying the backend fails; image generation is
          also disabled (``ENABLED = False``) in that case.
      """
      try:
          engine = app.state.config.ENGINE
          if engine == "openai":
              return [
                  {"id": "dall-e-2", "name": "DALL·E 2"},
                  {"id": "dall-e-3", "name": "DALL·E 3"},
              ]
          elif engine == "comfyui":
              r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
              r.raise_for_status()
              return _comfyui_model_choices(
                  r.json(),
                  json.loads(app.state.config.COMFYUI_WORKFLOW),
                  app.state.config.COMFYUI_WORKFLOW_NODES,
              )
          elif engine in ("automatic1111", ""):
              r = requests.get(
                  url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
                  headers={"authorization": get_automatic1111_api_auth()},
              )
              r.raise_for_status()
              return [
                  {"id": model["title"], "name": model["model_name"]}
                  for model in r.json()
              ]
      except Exception as e:
          app.state.config.ENABLED = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  # Request body of POST /generations.
  class GenerateImageForm(BaseModel):
      # Used only by the automatic1111 engine (it switches the active checkpoint);
      # the openai and comfyui engines use the configured MODEL.
      model: Optional[str] = None
      prompt: str
      # Used only by the openai engine; otherwise IMAGE_SIZE from config applies.
      size: Optional[str] = None
      n: int = 1  # number of images (sent as "batch_size" to automatic1111)
      negative_prompt: Optional[str] = None  # forwarded to comfyui and automatic1111
  def _decode_b64_image(b64_str):
      """Split a base64 image payload into ``(file_extension, raw_bytes)``.

      Accepts a bare base64 string or a data URL such as
      ``data:image/png;base64,<payload>``. The extension is guessed from the data
      URL's MIME type and falls back to ``.png`` when there is no header or the
      type is not recognised.
      """
      extension = ".png"
      encoded = b64_str
      if "," in b64_str:
          header, encoded = b64_str.split(",", 1)
          # "data:image/png;base64" -> "image/png". The "data:" scheme must be
          # dropped: guess_extension("data:image/png") returns None.
          mime_type = header.split(";")[0].strip().removeprefix("data:")
          extension = mimetypes.guess_extension(mime_type) or ".png"
      return extension, base64.b64decode(encoded)


  def save_b64_image(b64_str):
      """Decode a base64 image and write it to IMAGE_CACHE_DIR under a UUID name.

      Args:
          b64_str: Bare base64 image data or a ``data:<mime>;base64,<data>`` URL.

      Returns:
          The cached file name (``<uuid><ext>``), or ``None`` if decoding or
          writing failed (the error is logged).
      """
      try:
          extension, img_data = _decode_b64_image(b64_str)
          image_filename = f"{uuid.uuid4()}{extension}"
          IMAGE_CACHE_DIR.joinpath(image_filename).write_bytes(img_data)
          return image_filename
      except Exception as e:
          log.exception(f"Error saving image: {e}")
          return None
  def save_url_image(url):
      """Download an image from ``url`` into IMAGE_CACHE_DIR under a UUID name.

      Returns:
          The cached file name (``<uuid><ext>``), or ``None`` when the download
          fails, the response is not an image, or its MIME type has no known
          extension (each failure is logged).
      """
      image_id = str(uuid.uuid4())
      try:
          # stream=True so iter_content() streams the body instead of buffering
          # it all; the with-block closes the connection.
          with requests.get(url, stream=True) as r:
              r.raise_for_status()
              # Drop parameters such as "; charset=..." before the lookup.
              mime_type = (
                  r.headers.get("content-type", "").split(";")[0].strip().lower()
              )
              if mime_type.split("/")[0] != "image":
                  log.error("Url does not point to an image.")
                  return None

              image_format = mimetypes.guess_extension(mime_type)
              if not image_format:
                  raise ValueError("Could not determine image type from MIME type")

              image_filename = f"{image_id}{image_format}"
              with open(IMAGE_CACHE_DIR.joinpath(image_filename), "wb") as image_file:
                  for chunk in r.iter_content(chunk_size=8192):
                      image_file.write(chunk)
              return image_filename
      except Exception as e:
          log.exception(f"Error saving image: {e}")
          return None
  def _register_cached_image(image_filename, metadata):
      """Write the JSON sidecar for a cached image and return its URL entry.

      Args:
          image_filename: Name returned by save_b64_image / save_url_image, or
              ``None`` if saving the image failed.
          metadata: JSON-serialisable request parameters, stored as
              ``<image_filename>.json`` next to the image.

      Returns:
          ``{"url": "/cache/image/generations/<image_filename>"}``, or ``None``
          when ``image_filename`` is ``None``. In that case nothing is written,
          rather than emitting a ".../None" URL and a "None.json" sidecar.
      """
      if image_filename is None:
          log.error("Generated image could not be saved; omitting it from the response.")
          return None
      with open(IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json"), "w") as f:
          json.dump(metadata, f)
      return {"url": f"/cache/image/generations/{image_filename}"}


  def _backend_error_message(r, default):
      """Return the error message carried by a backend response, else ``default``.

      Recognises ``{"error": {"message": "..."}}`` and ``{"error": "..."}`` bodies.
      A missing response, a non-JSON body or any other shape yields ``default``
      instead of raising from inside the caller's exception handler.
      """
      if r is None:
          return default
      try:
          body = r.json()
      except ValueError:  # requests' JSON decode errors subclass ValueError
          return default
      error = body.get("error") if isinstance(body, dict) else None
      if isinstance(error, dict) and "message" in error:
          return error["message"]
      if isinstance(error, str) and error:
          return error
      return default


  @app.post("/generations")
  async def image_generations(
      form_data: GenerateImageForm,
      user=Depends(get_verified_user),
  ):
      """Generate images with the configured engine and cache them locally.

      Width/height come from IMAGE_SIZE (the openai engine lets ``form_data.size``
      override the size it sends). IMAGE_STEPS and the negative prompt are
      forwarded to comfyui and automatic1111. Each image is cached in
      IMAGE_CACHE_DIR with a ``<filename>.json`` sidecar of the request parameters.

      Returns:
          A list of ``{"url": "/cache/image/generations/<filename>"}`` entries;
          images that could not be saved are omitted. An engine not handled here
          yields ``None``.

      Raises:
          HTTPException(400): on any failure, using the backend's error message
          when its response body provides one.
      """
      width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
      # Last backend HTTP response; the error handler inspects it for a message.
      r = None
      try:
          if app.state.config.ENGINE == "openai":
              headers = {
                  "Authorization": f"Bearer {app.state.config.OPENAI_API_KEY}",
                  "Content-Type": "application/json",
              }
              data = {
                  "model": (
                      app.state.config.MODEL
                      if app.state.config.MODEL != ""
                      else "dall-e-2"
                  ),
                  "prompt": form_data.prompt,
                  "n": form_data.n,
                  "size": (
                      form_data.size if form_data.size else app.state.config.IMAGE_SIZE
                  ),
                  "response_format": "b64_json",
              }
              r = requests.post(
                  url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
                  json=data,
                  headers=headers,
              )
              r.raise_for_status()
              res = r.json()
              images = []
              for image in res["data"]:
                  entry = _register_cached_image(save_b64_image(image["b64_json"]), data)
                  if entry is not None:
                      images.append(entry)
              return images

          elif app.state.config.ENGINE == "comfyui":
              data = {
                  "prompt": form_data.prompt,
                  "width": width,
                  "height": height,
                  "n": form_data.n,
              }
              if app.state.config.IMAGE_STEPS is not None:
                  data["steps"] = app.state.config.IMAGE_STEPS
              if form_data.negative_prompt is not None:
                  data["negative_prompt"] = form_data.negative_prompt
              # Separate name so the incoming request form is not shadowed.
              comfyui_form = ComfyUIGenerateImageForm(
                  workflow=ComfyUIWorkflow(
                      workflow=app.state.config.COMFYUI_WORKFLOW,
                      nodes=app.state.config.COMFYUI_WORKFLOW_NODES,
                  ),
                  **data,
              )
              res = await comfyui_generate_image(
                  app.state.config.MODEL,
                  comfyui_form,
                  user.id,
                  app.state.config.COMFYUI_BASE_URL,
              )
              log.debug(f"res: {res}")
              metadata = comfyui_form.model_dump(exclude_none=True)
              images = []
              for image in res["data"]:
                  entry = _register_cached_image(save_url_image(image["url"]), metadata)
                  if entry is not None:
                      images.append(entry)
              log.debug(f"images: {images}")
              return images

          elif app.state.config.ENGINE in ("automatic1111", ""):
              if form_data.model:
                  set_image_model(form_data.model)
              data = {
                  "prompt": form_data.prompt,
                  "batch_size": form_data.n,
                  "width": width,
                  "height": height,
              }
              if app.state.config.IMAGE_STEPS is not None:
                  data["steps"] = app.state.config.IMAGE_STEPS
              if form_data.negative_prompt is not None:
                  data["negative_prompt"] = form_data.negative_prompt
              r = requests.post(
                  url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                  json=data,
                  headers={"authorization": get_automatic1111_api_auth()},
              )
              # Fail on HTTP errors here; the handler below surfaces the body's error.
              r.raise_for_status()
              res = r.json()
              log.debug(f"res: {res}")
              images = []
              for image in res["images"]:
                  entry = _register_cached_image(
                      save_b64_image(image), {**data, "info": res["info"]}
                  )
                  if entry is not None:
                      images.append(entry)
              return images
      except Exception as e:
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.DEFAULT(_backend_error_message(r, e)),
          ) from e
  463. if "error" in data:
  464. error = data["error"]["message"]
  465. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))