main.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. import re
  2. import requests
  3. import base64
  4. from fastapi import (
  5. FastAPI,
  6. Request,
  7. Depends,
  8. HTTPException,
  9. status,
  10. UploadFile,
  11. File,
  12. Form,
  13. )
  14. from fastapi.middleware.cors import CORSMiddleware
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_verified_user,
  18. get_admin_user,
  19. )
  20. from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
  21. from utils.misc import calculate_sha256
  22. from typing import Optional
  23. from pydantic import BaseModel
  24. from pathlib import Path
  25. import mimetypes
  26. import uuid
  27. import base64
  28. import json
  29. import logging
  30. from config import (
  31. SRC_LOG_LEVELS,
  32. CACHE_DIR,
  33. IMAGE_GENERATION_ENGINE,
  34. ENABLE_IMAGE_GENERATION,
  35. AUTOMATIC1111_BASE_URL,
  36. AUTOMATIC1111_API_AUTH,
  37. COMFYUI_BASE_URL,
  38. COMFYUI_CFG_SCALE,
  39. COMFYUI_SAMPLER,
  40. COMFYUI_SCHEDULER,
  41. COMFYUI_SD3,
  42. COMFYUI_FLUX,
  43. COMFYUI_FLUX_WEIGHT_DTYPE,
  44. COMFYUI_FLUX_FP8_CLIP,
  45. IMAGES_OPENAI_API_BASE_URL,
  46. IMAGES_OPENAI_API_KEY,
  47. IMAGE_GENERATION_MODEL,
  48. IMAGE_SIZE,
  49. IMAGE_STEPS,
  50. AppConfig,
  51. CORS_ALLOW_ORIGIN,
  52. )
  53. log = logging.getLogger(__name__)
  54. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  55. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  56. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  57. app = FastAPI()
  58. app.add_middleware(
  59. CORSMiddleware,
  60. allow_origins=CORS_ALLOW_ORIGIN,
  61. allow_credentials=True,
  62. allow_methods=["*"],
  63. allow_headers=["*"],
  64. )
  65. app.state.config = AppConfig()
  66. app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
  67. app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
  68. app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  69. app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
  70. app.state.config.MODEL = IMAGE_GENERATION_MODEL
  71. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  72. app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
  73. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  74. app.state.config.IMAGE_SIZE = IMAGE_SIZE
  75. app.state.config.IMAGE_STEPS = IMAGE_STEPS
  76. app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
  77. app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
  78. app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
  79. app.state.config.COMFYUI_SD3 = COMFYUI_SD3
  80. app.state.config.COMFYUI_FLUX = COMFYUI_FLUX
  81. app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE
  82. app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
  83. def get_automatic1111_api_auth():
  84. if app.state.config.AUTOMATIC1111_API_AUTH is None:
  85. return ""
  86. else:
  87. auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
  88. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  89. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  90. return f"Basic {auth1111_base64_encoded_string}"
  91. @app.get("/config")
  92. async def get_config(request: Request, user=Depends(get_admin_user)):
  93. return {
  94. "engine": app.state.config.ENGINE,
  95. "enabled": app.state.config.ENABLED,
  96. }
class ConfigUpdateForm(BaseModel):
    # Request body for POST /config/update.
    # engine: stored verbatim; handlers in this module branch on "openai" and
    # "comfyui" and treat any other value as AUTOMATIC1111.
    engine: str
    # enabled: global on/off switch for image generation.
    enabled: bool
  100. @app.post("/config/update")
  101. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  102. app.state.config.ENGINE = form_data.engine
  103. app.state.config.ENABLED = form_data.enabled
  104. return {
  105. "engine": app.state.config.ENGINE,
  106. "enabled": app.state.config.ENABLED,
  107. }
class EngineUrlUpdateForm(BaseModel):
    # Request body for POST /url/update. In that handler a field left as None
    # restores the value imported from `config`; it does not keep the current one.
    AUTOMATIC1111_BASE_URL: Optional[str] = None
    AUTOMATIC1111_API_AUTH: Optional[str] = None
    COMFYUI_BASE_URL: Optional[str] = None
  112. @app.get("/url")
  113. async def get_engine_url(user=Depends(get_admin_user)):
  114. return {
  115. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  116. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  117. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  118. }
  119. @app.post("/url/update")
  120. async def update_engine_url(
  121. form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  122. ):
  123. if form_data.AUTOMATIC1111_BASE_URL is None:
  124. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  125. else:
  126. url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
  127. try:
  128. r = requests.head(url)
  129. r.raise_for_status()
  130. app.state.config.AUTOMATIC1111_BASE_URL = url
  131. except Exception as e:
  132. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  133. if form_data.COMFYUI_BASE_URL is None:
  134. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  135. else:
  136. url = form_data.COMFYUI_BASE_URL.strip("/")
  137. try:
  138. r = requests.head(url)
  139. r.raise_for_status()
  140. app.state.config.COMFYUI_BASE_URL = url
  141. except Exception as e:
  142. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  143. if form_data.AUTOMATIC1111_API_AUTH is None:
  144. app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
  145. else:
  146. app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
  147. return {
  148. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  149. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  150. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  151. "status": True,
  152. }
class OpenAIConfigUpdateForm(BaseModel):
    # Request body for POST /openai/config/update.
    # url: OpenAI-compatible API base URL ("/images/generations" is appended to it).
    url: str
    # key: API key sent as a Bearer token; the handler rejects an empty key.
    key: str
  156. @app.get("/openai/config")
  157. async def get_openai_config(user=Depends(get_admin_user)):
  158. return {
  159. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  160. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  161. }
  162. @app.post("/openai/config/update")
  163. async def update_openai_config(
  164. form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
  165. ):
  166. if form_data.key == "":
  167. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  168. app.state.config.OPENAI_API_BASE_URL = form_data.url
  169. app.state.config.OPENAI_API_KEY = form_data.key
  170. return {
  171. "status": True,
  172. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  173. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  174. }
class ImageSizeUpdateForm(BaseModel):
    # Request body for POST /size/update; size is expected as "WIDTHxHEIGHT", e.g. "512x512".
    size: str
  177. @app.get("/size")
  178. async def get_image_size(user=Depends(get_admin_user)):
  179. return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
  180. @app.post("/size/update")
  181. async def update_image_size(
  182. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  183. ):
  184. pattern = r"^\d+x\d+$" # Regular expression pattern
  185. if re.match(pattern, form_data.size):
  186. app.state.config.IMAGE_SIZE = form_data.size
  187. return {
  188. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  189. "status": True,
  190. }
  191. else:
  192. raise HTTPException(
  193. status_code=400,
  194. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  195. )
class ImageStepsUpdateForm(BaseModel):
    # Request body for POST /steps/update; the handler accepts only steps >= 0.
    steps: int
  198. @app.get("/steps")
  199. async def get_image_size(user=Depends(get_admin_user)):
  200. return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
  201. @app.post("/steps/update")
  202. async def update_image_size(
  203. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  204. ):
  205. if form_data.steps >= 0:
  206. app.state.config.IMAGE_STEPS = form_data.steps
  207. return {
  208. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  209. "status": True,
  210. }
  211. else:
  212. raise HTTPException(
  213. status_code=400,
  214. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  215. )
  216. @app.get("/models")
  217. def get_models(user=Depends(get_verified_user)):
  218. try:
  219. if app.state.config.ENGINE == "openai":
  220. return [
  221. {"id": "dall-e-2", "name": "DALL·E 2"},
  222. {"id": "dall-e-3", "name": "DALL·E 3"},
  223. ]
  224. elif app.state.config.ENGINE == "comfyui":
  225. r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
  226. info = r.json()
  227. return list(
  228. map(
  229. lambda model: {"id": model, "name": model},
  230. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  231. )
  232. )
  233. else:
  234. r = requests.get(
  235. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  236. headers={"authorization": get_automatic1111_api_auth()},
  237. )
  238. models = r.json()
  239. return list(
  240. map(
  241. lambda model: {"id": model["title"], "name": model["model_name"]},
  242. models,
  243. )
  244. )
  245. except Exception as e:
  246. app.state.config.ENABLED = False
  247. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  248. @app.get("/models/default")
  249. async def get_default_model(user=Depends(get_admin_user)):
  250. try:
  251. if app.state.config.ENGINE == "openai":
  252. return {
  253. "model": (
  254. app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
  255. )
  256. }
  257. elif app.state.config.ENGINE == "comfyui":
  258. return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
  259. else:
  260. r = requests.get(
  261. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  262. headers={"authorization": get_automatic1111_api_auth()},
  263. )
  264. options = r.json()
  265. return {"model": options["sd_model_checkpoint"]}
  266. except Exception as e:
  267. app.state.config.ENABLED = False
  268. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class UpdateModelForm(BaseModel):
    # Request body for POST /models/default/update; model is the name to make default.
    model: str
  271. def set_model_handler(model: str):
  272. if app.state.config.ENGINE in ["openai", "comfyui"]:
  273. app.state.config.MODEL = model
  274. return app.state.config.MODEL
  275. else:
  276. api_auth = get_automatic1111_api_auth()
  277. r = requests.get(
  278. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  279. headers={"authorization": api_auth},
  280. )
  281. options = r.json()
  282. if model != options["sd_model_checkpoint"]:
  283. options["sd_model_checkpoint"] = model
  284. r = requests.post(
  285. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  286. json=options,
  287. headers={"authorization": api_auth},
  288. )
  289. return options
  290. @app.post("/models/default/update")
  291. def update_default_model(
  292. form_data: UpdateModelForm,
  293. user=Depends(get_verified_user),
  294. ):
  295. return set_model_handler(form_data.model)
class GenerateImageForm(BaseModel):
    # Request body for POST /generations.
    # model: only used on the AUTOMATIC1111 path, which switches the server
    # checkpoint before generating.
    model: Optional[str] = None
    prompt: str
    # n: number of images; sent as "n" (OpenAI, ComfyUI) or "batch_size" (AUTOMATIC1111).
    n: int = 1
    # size: "WIDTHxHEIGHT"; only the OpenAI path honors it, the others use IMAGE_SIZE.
    size: Optional[str] = None
    # negative_prompt: forwarded to ComfyUI and AUTOMATIC1111, not to OpenAI.
    negative_prompt: Optional[str] = None
  302. def save_b64_image(b64_str):
  303. try:
  304. image_id = str(uuid.uuid4())
  305. if "," in b64_str:
  306. header, encoded = b64_str.split(",", 1)
  307. mime_type = header.split(";")[0]
  308. img_data = base64.b64decode(encoded)
  309. image_format = mimetypes.guess_extension(mime_type)
  310. image_filename = f"{image_id}{image_format}"
  311. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  312. with open(file_path, "wb") as f:
  313. f.write(img_data)
  314. return image_filename
  315. else:
  316. image_filename = f"{image_id}.png"
  317. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  318. img_data = base64.b64decode(b64_str)
  319. # Write the image data to a file
  320. with open(file_path, "wb") as f:
  321. f.write(img_data)
  322. return image_filename
  323. except Exception as e:
  324. log.exception(f"Error saving image: {e}")
  325. return None
  326. def save_url_image(url):
  327. image_id = str(uuid.uuid4())
  328. try:
  329. r = requests.get(url)
  330. r.raise_for_status()
  331. if r.headers["content-type"].split("/")[0] == "image":
  332. mime_type = r.headers["content-type"]
  333. image_format = mimetypes.guess_extension(mime_type)
  334. if not image_format:
  335. raise ValueError("Could not determine image type from MIME type")
  336. image_filename = f"{image_id}{image_format}"
  337. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  338. with open(file_path, "wb") as image_file:
  339. for chunk in r.iter_content(chunk_size=8192):
  340. image_file.write(chunk)
  341. return image_filename
  342. else:
  343. log.error(f"Url does not point to an image.")
  344. return None
  345. except Exception as e:
  346. log.exception(f"Error saving image: {e}")
  347. return None
  348. @app.post("/generations")
  349. async def image_generations(
  350. form_data: GenerateImageForm,
  351. user=Depends(get_verified_user),
  352. ):
  353. width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
  354. r = None
  355. try:
  356. if app.state.config.ENGINE == "openai":
  357. headers = {}
  358. headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
  359. headers["Content-Type"] = "application/json"
  360. data = {
  361. "model": (
  362. app.state.config.MODEL
  363. if app.state.config.MODEL != ""
  364. else "dall-e-2"
  365. ),
  366. "prompt": form_data.prompt,
  367. "n": form_data.n,
  368. "size": (
  369. form_data.size if form_data.size else app.state.config.IMAGE_SIZE
  370. ),
  371. "response_format": "b64_json",
  372. }
  373. r = requests.post(
  374. url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
  375. json=data,
  376. headers=headers,
  377. )
  378. r.raise_for_status()
  379. res = r.json()
  380. images = []
  381. for image in res["data"]:
  382. image_filename = save_b64_image(image["b64_json"])
  383. images.append({"url": f"/cache/image/generations/{image_filename}"})
  384. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  385. with open(file_body_path, "w") as f:
  386. json.dump(data, f)
  387. return images
  388. elif app.state.config.ENGINE == "comfyui":
  389. data = {
  390. "prompt": form_data.prompt,
  391. "width": width,
  392. "height": height,
  393. "n": form_data.n,
  394. }
  395. if app.state.config.IMAGE_STEPS is not None:
  396. data["steps"] = app.state.config.IMAGE_STEPS
  397. if form_data.negative_prompt is not None:
  398. data["negative_prompt"] = form_data.negative_prompt
  399. if app.state.config.COMFYUI_CFG_SCALE:
  400. data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
  401. if app.state.config.COMFYUI_SAMPLER is not None:
  402. data["sampler"] = app.state.config.COMFYUI_SAMPLER
  403. if app.state.config.COMFYUI_SCHEDULER is not None:
  404. data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
  405. if app.state.config.COMFYUI_SD3 is not None:
  406. data["sd3"] = app.state.config.COMFYUI_SD3
  407. if app.state.config.COMFYUI_FLUX is not None:
  408. data["flux"] = app.state.config.COMFYUI_FLUX
  409. if app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE is not None:
  410. data["flux_weight_dtype"] = app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE
  411. if app.state.config.COMFYUI_FLUX_FP8_CLIP is not None:
  412. data["flux_fp8_clip"] = app.state.config.COMFYUI_FLUX_FP8_CLIP
  413. data = ImageGenerationPayload(**data)
  414. res = await comfyui_generate_image(
  415. app.state.config.MODEL,
  416. data,
  417. user.id,
  418. app.state.config.COMFYUI_BASE_URL,
  419. )
  420. log.debug(f"res: {res}")
  421. images = []
  422. for image in res["data"]:
  423. image_filename = save_url_image(image["url"])
  424. images.append({"url": f"/cache/image/generations/{image_filename}"})
  425. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  426. with open(file_body_path, "w") as f:
  427. json.dump(data.model_dump(exclude_none=True), f)
  428. log.debug(f"images: {images}")
  429. return images
  430. else:
  431. if form_data.model:
  432. set_model_handler(form_data.model)
  433. data = {
  434. "prompt": form_data.prompt,
  435. "batch_size": form_data.n,
  436. "width": width,
  437. "height": height,
  438. }
  439. if app.state.config.IMAGE_STEPS is not None:
  440. data["steps"] = app.state.config.IMAGE_STEPS
  441. if form_data.negative_prompt is not None:
  442. data["negative_prompt"] = form_data.negative_prompt
  443. r = requests.post(
  444. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  445. json=data,
  446. headers={"authorization": get_automatic1111_api_auth()},
  447. )
  448. res = r.json()
  449. log.debug(f"res: {res}")
  450. images = []
  451. for image in res["images"]:
  452. image_filename = save_b64_image(image)
  453. images.append({"url": f"/cache/image/generations/{image_filename}"})
  454. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  455. with open(file_body_path, "w") as f:
  456. json.dump({**data, "info": res["info"]}, f)
  457. return images
  458. except Exception as e:
  459. error = e
  460. if r != None:
  461. data = r.json()
  462. if "error" in data:
  463. error = data["error"]["message"]
  464. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))