main.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. from fastapi import (
  2. FastAPI,
  3. Request,
  4. Depends,
  5. HTTPException,
  6. )
  7. from fastapi.middleware.cors import CORSMiddleware
  8. from typing import Optional
  9. from pydantic import BaseModel
  10. from pathlib import Path
  11. import mimetypes
  12. import uuid
  13. import base64
  14. import json
  15. import logging
  16. import re
  17. import requests
  18. import asyncio
  19. from utils.utils import (
  20. get_verified_user,
  21. get_admin_user,
  22. )
  23. from apps.images.utils.comfyui import (
  24. ComfyUIWorkflow,
  25. ComfyUIGenerateImageForm,
  26. comfyui_generate_image,
  27. )
  28. from constants import ERROR_MESSAGES
  29. from config import (
  30. SRC_LOG_LEVELS,
  31. CACHE_DIR,
  32. IMAGE_GENERATION_ENGINE,
  33. ENABLE_IMAGE_GENERATION,
  34. AUTOMATIC1111_BASE_URL,
  35. AUTOMATIC1111_API_AUTH,
  36. COMFYUI_BASE_URL,
  37. COMFYUI_WORKFLOW,
  38. COMFYUI_WORKFLOW_NODES,
  39. IMAGES_OPENAI_API_BASE_URL,
  40. IMAGES_OPENAI_API_KEY,
  41. IMAGE_GENERATION_MODEL,
  42. IMAGE_SIZE,
  43. IMAGE_STEPS,
  44. CORS_ALLOW_ORIGIN,
  45. AppConfig,
  46. )
  # Module logger; its level is taken from the "IMAGES" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["IMAGES"])

  # Directory that receives generated image files and their JSON sidecar files.
  # It is created at import time if it does not exist yet.
  IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

  app = FastAPI()
  app.add_middleware(
      CORSMiddleware,
      allow_origins=CORS_ALLOW_ORIGIN,
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

  # Runtime-mutable settings, seeded from the values imported from `config`.
  # The endpoints below read and update them through `app.state.config`.
  app.state.config = AppConfig()

  # General switches.
  app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
  app.state.config.ENABLED = ENABLE_IMAGE_GENERATION

  # OpenAI-compatible backend.
  app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

  app.state.config.MODEL = IMAGE_GENERATION_MODEL

  # AUTOMATIC1111 backend.
  app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH

  # ComfyUI backend.
  app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
  app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES

  # Generation parameters shared by the engines.
  app.state.config.IMAGE_SIZE = IMAGE_SIZE
  app.state.config.IMAGE_STEPS = IMAGE_STEPS
  @app.get("/config")
  async def get_config(request: Request, user=Depends(get_admin_user)):
      """Return the current engine settings, grouped per backend.

      The route depends on `get_admin_user`. `request` is accepted but not read.
      """
      config = app.state.config

      openai_settings = {
          "OPENAI_API_BASE_URL": config.OPENAI_API_BASE_URL,
          "OPENAI_API_KEY": config.OPENAI_API_KEY,
      }
      automatic1111_settings = {
          "AUTOMATIC1111_BASE_URL": config.AUTOMATIC1111_BASE_URL,
          "AUTOMATIC1111_API_AUTH": config.AUTOMATIC1111_API_AUTH,
      }
      comfyui_settings = {
          "COMFYUI_BASE_URL": config.COMFYUI_BASE_URL,
          "COMFYUI_WORKFLOW": config.COMFYUI_WORKFLOW,
          "COMFYUI_WORKFLOW_NODES": config.COMFYUI_WORKFLOW_NODES,
      }

      return {
          "enabled": config.ENABLED,
          "engine": config.ENGINE,
          "openai": openai_settings,
          "automatic1111": automatic1111_settings,
          "comfyui": comfyui_settings,
      }
  class OpenAIConfigForm(BaseModel):
      """Request payload section holding the OpenAI-compatible backend settings."""

      OPENAI_API_BASE_URL: str
      OPENAI_API_KEY: str
  class Automatic1111ConfigForm(BaseModel):
      """Request payload section holding the AUTOMATIC1111 backend settings."""

      AUTOMATIC1111_BASE_URL: str
      AUTOMATIC1111_API_AUTH: str
  class ComfyUIConfigForm(BaseModel):
      """Request payload section holding the ComfyUI backend settings."""

      COMFYUI_BASE_URL: str
      # Workflow as text; the /models endpoint parses it with json.loads.
      COMFYUI_WORKFLOW: str
      COMFYUI_WORKFLOW_NODES: list[dict]
  class ConfigForm(BaseModel):
      """Body of POST /config/update: global switches plus one section per backend."""

      enabled: bool
      engine: str
      openai: OpenAIConfigForm
      automatic1111: Automatic1111ConfigForm
      comfyui: ComfyUIConfigForm
  @app.post("/config/update")
  async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
      """Overwrite the engine settings with the submitted form and echo them back.

      The route depends on `get_admin_user`. The response has the same shape as
      the one produced by GET /config.
      """
      config = app.state.config
      openai_form = form_data.openai
      automatic1111_form = form_data.automatic1111
      comfyui_form = form_data.comfyui

      config.ENGINE = form_data.engine
      config.ENABLED = form_data.enabled

      config.OPENAI_API_BASE_URL = openai_form.OPENAI_API_BASE_URL
      config.OPENAI_API_KEY = openai_form.OPENAI_API_KEY

      config.AUTOMATIC1111_BASE_URL = automatic1111_form.AUTOMATIC1111_BASE_URL
      config.AUTOMATIC1111_API_AUTH = automatic1111_form.AUTOMATIC1111_API_AUTH

      config.COMFYUI_BASE_URL = comfyui_form.COMFYUI_BASE_URL
      config.COMFYUI_WORKFLOW = comfyui_form.COMFYUI_WORKFLOW
      config.COMFYUI_WORKFLOW_NODES = comfyui_form.COMFYUI_WORKFLOW_NODES

      # Read the values back from the config object rather than from the form.
      return {
          "enabled": config.ENABLED,
          "engine": config.ENGINE,
          "openai": {
              "OPENAI_API_BASE_URL": config.OPENAI_API_BASE_URL,
              "OPENAI_API_KEY": config.OPENAI_API_KEY,
          },
          "automatic1111": {
              "AUTOMATIC1111_BASE_URL": config.AUTOMATIC1111_BASE_URL,
              "AUTOMATIC1111_API_AUTH": config.AUTOMATIC1111_API_AUTH,
          },
          "comfyui": {
              "COMFYUI_BASE_URL": config.COMFYUI_BASE_URL,
              "COMFYUI_WORKFLOW": config.COMFYUI_WORKFLOW,
              "COMFYUI_WORKFLOW_NODES": config.COMFYUI_WORKFLOW_NODES,
          },
      }
  def get_automatic1111_api_auth():
      """Build the `authorization` header value used for AUTOMATIC1111 requests.

      Returns:
          An empty string when the AUTOMATIC1111_API_AUTH setting is None;
          otherwise "Basic " followed by the base64 encoding of the setting's
          UTF-8 bytes.
      """
      credentials = app.state.config.AUTOMATIC1111_API_AUTH
      if credentials is None:
          return ""

      token = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
      return f"Basic {token}"
  @app.get("/config/url/verify")
  async def verify_url(user=Depends(get_admin_user)):
      """Check that the configured backend URL answers a simple GET request.

      For the "automatic1111" engine this requests `/sdapi/v1/options` (with the
      configured authorization header); for "comfyui" it requests `/object_info`.
      Any other engine is not probed.

      Returns:
          True when the probe succeeds or when the engine is not probed.

      Raises:
          HTTPException: status 400 with ERROR_MESSAGES.INVALID_URL when the
              probe fails for any reason. Image generation is also switched off
              (ENABLED = False) in that case.
      """
      engine = app.state.config.ENGINE
      if engine not in ("automatic1111", "comfyui"):
          return True

      try:
          if engine == "automatic1111":
              request_kwargs = {
                  "url": f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                  "headers": {"authorization": get_automatic1111_api_auth()},
              }
          else:
              request_kwargs = {
                  "url": f"{app.state.config.COMFYUI_BASE_URL}/object_info",
              }

          # `requests` is blocking; run it in a worker thread so the event loop
          # keeps serving other requests while the backend is being probed.
          r = await asyncio.to_thread(requests.get, **request_kwargs)
          r.raise_for_status()
          return True
      except Exception as e:
          app.state.config.ENABLED = False
          raise HTTPException(
              status_code=400, detail=ERROR_MESSAGES.INVALID_URL
          ) from e
  def set_image_model(model: str):
      """Store `model` as the active model and, for AUTOMATIC1111, select it remotely.

      When the engine is "" or "automatic1111", the backend's options are fetched
      from `/sdapi/v1/options`; if its `sd_model_checkpoint` differs from `model`,
      the options are posted back with the checkpoint replaced.

      Returns:
          The value of the MODEL setting after the update.
      """
      app.state.config.MODEL = model

      if app.state.config.ENGINE in ["", "automatic1111"]:
          auth_headers = {"authorization": get_automatic1111_api_auth()}
          options_url = f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"

          options = requests.get(url=options_url, headers=auth_headers).json()

          if model != options["sd_model_checkpoint"]:
              options["sd_model_checkpoint"] = model
              requests.post(url=options_url, json=options, headers=auth_headers)

      return app.state.config.MODEL
  def get_image_model():
      """Return the model identifier in effect for the configured engine.

      - "openai": the MODEL setting, or "dall-e-2" when it is empty.
      - "comfyui": the MODEL setting, or "" when it is empty.
      - "automatic1111" or "": the `sd_model_checkpoint` value reported by the
        backend's `/sdapi/v1/options` endpoint.

      For any other engine value the function falls off the end and returns None.

      Raises:
          HTTPException: status 400 when the AUTOMATIC1111 lookup fails; image
              generation is also switched off (ENABLED = False) in that case.
      """
      engine = app.state.config.ENGINE

      if engine == "openai":
          return app.state.config.MODEL or "dall-e-2"

      if engine == "comfyui":
          return app.state.config.MODEL or ""

      if engine in ("automatic1111", ""):
          try:
              response = requests.get(
                  url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                  headers={"authorization": get_automatic1111_api_auth()},
              )
              return response.json()["sd_model_checkpoint"]
          except Exception as e:
              app.state.config.ENABLED = False
              raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  class ImageConfigForm(BaseModel):
      """Body of POST /image/config/update."""

      MODEL: str
      # Expected as "<width>x<height>", e.g. "512x512"; the update endpoint checks it.
      IMAGE_SIZE: str
      # Must be >= 0; the update endpoint checks it.
      IMAGE_STEPS: int
  @app.get("/image/config")
  async def get_image_config(user=Depends(get_admin_user)):
      """Return the current model, image size and step count settings."""
      config = app.state.config
      return {
          "MODEL": config.MODEL,
          "IMAGE_SIZE": config.IMAGE_SIZE,
          "IMAGE_STEPS": config.IMAGE_STEPS,
      }
  @app.post("/image/config/update")
  async def update_image_config(
      form_data: ImageConfigForm, user=Depends(get_admin_user)
  ):
      """Validate and store the model, image size and step count.

      All fields are validated before anything is written, so a rejected request
      leaves the stored settings untouched.

      Returns:
          The stored MODEL, IMAGE_SIZE and IMAGE_STEPS values.

      Raises:
          HTTPException: status 400 when IMAGE_SIZE is not of the form
              "<digits>x<digits>" or when IMAGE_STEPS is negative.
      """
      # fullmatch (rather than match with "$") also rejects a trailing newline,
      # which would otherwise be stored and forwarded verbatim as the size.
      if not re.fullmatch(r"\d+x\d+", form_data.IMAGE_SIZE):
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
          )

      if form_data.IMAGE_STEPS < 0:
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
          )

      app.state.config.MODEL = form_data.MODEL
      app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
      app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS

      return {
          "MODEL": app.state.config.MODEL,
          "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
          "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
      }
  @app.get("/models")
  def get_models(user=Depends(get_verified_user)):
      """List the models offered by the configured engine as {"id", "name"} dicts.

      - "openai": a fixed list (dall-e-2, dall-e-3).
      - "comfyui": the choices of the first "*_name*" required input of the
        workflow's model node, as reported by the backend's `/object_info`;
        when no model node id is configured, the `ckpt_name` choices of
        `CheckpointLoaderSimple` are used instead.
      - "automatic1111" or "": the entries of `/sdapi/v1/sd-models`.

      For any other engine, or when a ComfyUI model node has no matching input,
      the function falls off the end and returns None.

      Raises:
          HTTPException: status 400 when anything fails; image generation is
              also switched off (ENABLED = False) in that case.
      """
      try:
          if app.state.config.ENGINE == "openai":
              return [
                  {"id": "dall-e-2", "name": "DALL·E 2"},
                  {"id": "dall-e-3", "name": "DALL·E 3"},
              ]
          elif app.state.config.ENGINE == "comfyui":
              r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
              info = r.json()

              workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)

              # NOTE(review): block nesting here was reconstructed from a paste
              # that lost its indentation; the `break` is assumed to sit inside
              # the innermost `if` — confirm against the original file.
              model_node_id = None
              for node in app.state.config.COMFYUI_WORKFLOW_NODES:
                  if node["type"] == "model":
                      if node["node_ids"]:
                          model_node_id = node["node_ids"][0]
                          break

              if model_node_id:
                  class_type = workflow[model_node_id]["class_type"]
                  # Diagnostic output goes through the module logger, not stdout.
                  log.debug(f"model node class_type: {class_type}")

                  required_inputs = info[class_type]["input"]["required"]
                  model_list_key = next(
                      (key for key in required_inputs if "_name" in key), None
                  )

                  if model_list_key:
                      return [
                          {"id": model, "name": model}
                          for model in required_inputs[model_list_key][0]
                      ]
              else:
                  # NOTE(review): this `else` is assumed to pair with
                  # `if model_node_id` (fallback when no model node is set).
                  checkpoints = info["CheckpointLoaderSimple"]["input"]["required"][
                      "ckpt_name"
                  ][0]
                  return [{"id": model, "name": model} for model in checkpoints]
          elif (
              app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
          ):
              r = requests.get(
                  url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
                  headers={"authorization": get_automatic1111_api_auth()},
              )
              models = r.json()
              return [
                  {"id": model["title"], "name": model["model_name"]}
                  for model in models
              ]
      except Exception as e:
          app.state.config.ENABLED = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  class GenerateImageForm(BaseModel):
      """Body of POST /generations."""

      model: Optional[str] = None
      prompt: str
      size: Optional[str] = None
      # Number of images requested.
      n: int = 1
      negative_prompt: Optional[str] = None
  def save_b64_image(b64_str):
      """Decode a base64 image and write it into IMAGE_CACHE_DIR.

      Args:
          b64_str: Either a bare base64 payload (saved as ".png") or a data URI
              of the form "data:<mime>;base64,<payload>", in which case the
              file extension is derived from the MIME type.

      Returns:
          The generated file name ("<uuid><extension>"), or None when decoding
          or writing fails (the error is logged).
      """
      try:
          image_id = str(uuid.uuid4())

          extension = ".png"
          encoded = b64_str
          if "," in b64_str:
              header, encoded = b64_str.split(",", 1)
              # The header looks like "data:image/png;base64". The "data:" scheme
              # must be removed, otherwise guess_extension() cannot recognise the
              # MIME type and the file would be named "<uuid>None".
              mime_type = header.split(";")[0].removeprefix("data:")
              extension = mimetypes.guess_extension(mime_type) or ".png"

          img_data = base64.b64decode(encoded)

          image_filename = f"{image_id}{extension}"
          file_path = IMAGE_CACHE_DIR / image_filename
          with open(file_path, "wb") as f:
              f.write(img_data)
          return image_filename
      except Exception as e:
          log.exception(f"Error saving image: {e}")
          return None
  def save_url_image(url):
      """Download the image at `url` and write it into IMAGE_CACHE_DIR.

      The response's content-type header must start with "image/"; the file
      extension is derived from that header.

      Returns:
          The generated file name ("<uuid><extension>"), or None when the URL
          does not serve an image or when anything fails (the error is logged).
      """
      image_id = str(uuid.uuid4())
      try:
          r = requests.get(url)
          r.raise_for_status()

          content_type = r.headers["content-type"]
          if content_type.split("/")[0] != "image":
              log.error("Url does not point to an image.")
              return None

          image_format = mimetypes.guess_extension(content_type)
          if not image_format:
              raise ValueError("Could not determine image type from MIME type")

          image_filename = f"{image_id}{image_format}"
          with open(IMAGE_CACHE_DIR.joinpath(image_filename), "wb") as image_file:
              for chunk in r.iter_content(chunk_size=8192):
                  image_file.write(chunk)
          return image_filename
      except Exception as e:
          log.exception(f"Error saving image: {e}")
          return None
  @app.post("/generations")
  async def image_generations(
      form_data: GenerateImageForm,
      user=Depends(get_verified_user),
  ):
      """Generate images with the configured engine and cache them locally.

      Each produced image is saved into IMAGE_CACHE_DIR, together with a
      "<image file name>.json" sidecar that records the request parameters.

      Returns:
          A list of {"url": "/cache/image/generations/<file name>"} dicts, one
          per generated image. For an engine value other than "openai",
          "comfyui", "automatic1111" or "", the function falls off the end and
          returns None.

      Raises:
          HTTPException: status 400 when generation fails. When the backend
              answered with a JSON body containing an "error" entry, that
              entry (or its "message") is used as the error detail.
      """
      width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))

      # Last HTTP response from the backend, if any; inspected by the error
      # handler to surface the backend's own error message.
      r = None
      try:
          if app.state.config.ENGINE == "openai":
              headers = {
                  "Authorization": f"Bearer {app.state.config.OPENAI_API_KEY}",
                  "Content-Type": "application/json",
              }

              data = {
                  "model": (
                      app.state.config.MODEL
                      if app.state.config.MODEL != ""
                      else "dall-e-2"
                  ),
                  "prompt": form_data.prompt,
                  "n": form_data.n,
                  "size": (
                      form_data.size if form_data.size else app.state.config.IMAGE_SIZE
                  ),
                  "response_format": "b64_json",
              }

              # `requests` is blocking; run it in a worker thread (as the
              # AUTOMATIC1111 branch does) so the event loop is not stalled.
              r = await asyncio.to_thread(
                  requests.post,
                  url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
                  json=data,
                  headers=headers,
              )
              r.raise_for_status()
              res = r.json()

              images = []
              for image in res["data"]:
                  # NOTE(review): save_b64_image returns None on failure, which
                  # would yield a ".../None" URL here — confirm intended handling.
                  image_filename = save_b64_image(image["b64_json"])
                  images.append({"url": f"/cache/image/generations/{image_filename}"})
                  file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
                  with open(file_body_path, "w") as f:
                      json.dump(data, f)

              return images

          elif app.state.config.ENGINE == "comfyui":
              data = {
                  "prompt": form_data.prompt,
                  "width": width,
                  "height": height,
                  "n": form_data.n,
              }

              if app.state.config.IMAGE_STEPS is not None:
                  data["steps"] = app.state.config.IMAGE_STEPS

              if form_data.negative_prompt is not None:
                  data["negative_prompt"] = form_data.negative_prompt

              # Kept under its own name so the incoming request form is not
              # rebound to an object of a different type.
              comfyui_form = ComfyUIGenerateImageForm(
                  **{
                      "workflow": ComfyUIWorkflow(
                          **{
                              "workflow": app.state.config.COMFYUI_WORKFLOW,
                              "nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
                          }
                      ),
                      **data,
                  }
              )
              res = await comfyui_generate_image(
                  app.state.config.MODEL,
                  comfyui_form,
                  user.id,
                  app.state.config.COMFYUI_BASE_URL,
              )
              log.debug(f"res: {res}")

              images = []
              for image in res["data"]:
                  image_filename = save_url_image(image["url"])
                  images.append({"url": f"/cache/image/generations/{image_filename}"})
                  file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
                  with open(file_body_path, "w") as f:
                      json.dump(comfyui_form.model_dump(exclude_none=True), f)

              log.debug(f"images: {images}")
              return images

          elif (
              app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
          ):
              if form_data.model:
                  set_image_model(form_data.model)

              data = {
                  "prompt": form_data.prompt,
                  "batch_size": form_data.n,
                  "width": width,
                  "height": height,
              }

              if app.state.config.IMAGE_STEPS is not None:
                  data["steps"] = app.state.config.IMAGE_STEPS

              if form_data.negative_prompt is not None:
                  data["negative_prompt"] = form_data.negative_prompt

              # Run the blocking requests.post call in a worker thread.
              r = await asyncio.to_thread(
                  requests.post,
                  url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                  json=data,
                  headers={"authorization": get_automatic1111_api_auth()},
              )
              res = r.json()
              log.debug(f"res: {res}")

              images = []
              for image in res["images"]:
                  image_filename = save_b64_image(image)
                  images.append({"url": f"/cache/image/generations/{image_filename}"})
                  file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
                  with open(file_body_path, "w") as f:
                      json.dump({**data, "info": res["info"]}, f)

              return images

      except Exception as e:
          error = e
          if r is not None:
              # Prefer the backend's own error message, but never let a
              # non-JSON body or an unexpected "error" shape raise from inside
              # this handler and mask the original failure.
              try:
                  payload = r.json()
              except ValueError:
                  payload = None
              if isinstance(payload, dict) and "error" in payload:
                  backend_error = payload["error"]
                  if isinstance(backend_error, dict) and "message" in backend_error:
                      error = backend_error["message"]
                  else:
                      error = backend_error
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))