main.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
  21. from utils.misc import calculate_sha256
  22. from typing import Optional
  23. from pydantic import BaseModel
  24. from pathlib import Path
  25. import mimetypes
  26. import uuid
  27. import base64
  28. import json
  29. import logging
  30. from config import (
  31. SRC_LOG_LEVELS,
  32. CACHE_DIR,
  33. IMAGE_GENERATION_ENGINE,
  34. ENABLE_IMAGE_GENERATION,
  35. AUTOMATIC1111_BASE_URL,
  36. COMFYUI_BASE_URL,
  37. COMFYUI_CFG_SCALE,
  38. COMFYUI_SAMPLER,
  39. COMFYUI_SCHEDULER,
  40. COMFYUI_SD3,
  41. IMAGES_OPENAI_API_BASE_URL,
  42. IMAGES_OPENAI_API_KEY,
  43. IMAGE_GENERATION_MODEL,
  44. IMAGE_SIZE,
  45. IMAGE_STEPS,
  46. AppConfig,
  47. )
  48. log = logging.getLogger(__name__)
  49. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  50. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  51. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  52. app = FastAPI()
  53. app.add_middleware(
  54. CORSMiddleware,
  55. allow_origins=["*"],
  56. allow_credentials=True,
  57. allow_methods=["*"],
  58. allow_headers=["*"],
  59. )
  60. app.state.config = AppConfig()
  61. app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
  62. app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
  63. app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  64. app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
  65. app.state.config.MODEL = IMAGE_GENERATION_MODEL
  66. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  67. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  68. app.state.config.IMAGE_SIZE = IMAGE_SIZE
  69. app.state.config.IMAGE_STEPS = IMAGE_STEPS
  70. app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
  71. app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
  72. app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
  73. app.state.config.COMFYUI_SD3 = COMFYUI_SD3
  74. @app.get("/config")
  75. async def get_config(request: Request, user=Depends(get_admin_user)):
  76. return {
  77. "engine": app.state.config.ENGINE,
  78. "enabled": app.state.config.ENABLED,
  79. }
  80. class ConfigUpdateForm(BaseModel):
  81. engine: str
  82. enabled: bool
  83. @app.post("/config/update")
  84. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  85. app.state.config.ENGINE = form_data.engine
  86. app.state.config.ENABLED = form_data.enabled
  87. return {
  88. "engine": app.state.config.ENGINE,
  89. "enabled": app.state.config.ENABLED,
  90. }
  91. class EngineUrlUpdateForm(BaseModel):
  92. AUTOMATIC1111_BASE_URL: Optional[str] = None
  93. COMFYUI_BASE_URL: Optional[str] = None
  94. @app.get("/url")
  95. async def get_engine_url(user=Depends(get_admin_user)):
  96. return {
  97. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  98. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  99. }
  100. @app.post("/url/update")
  101. async def update_engine_url(
  102. form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  103. ):
  104. if form_data.AUTOMATIC1111_BASE_URL == None:
  105. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  106. else:
  107. url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
  108. try:
  109. r = requests.head(url)
  110. app.state.config.AUTOMATIC1111_BASE_URL = url
  111. except Exception as e:
  112. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  113. if form_data.COMFYUI_BASE_URL == None:
  114. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  115. else:
  116. url = form_data.COMFYUI_BASE_URL.strip("/")
  117. try:
  118. r = requests.head(url)
  119. app.state.config.COMFYUI_BASE_URL = url
  120. except Exception as e:
  121. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  122. return {
  123. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  124. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  125. "status": True,
  126. }
  127. class OpenAIConfigUpdateForm(BaseModel):
  128. url: str
  129. key: str
  130. @app.get("/openai/config")
  131. async def get_openai_config(user=Depends(get_admin_user)):
  132. return {
  133. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  134. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  135. }
  136. @app.post("/openai/config/update")
  137. async def update_openai_config(
  138. form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
  139. ):
  140. if form_data.key == "":
  141. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  142. app.state.config.OPENAI_API_BASE_URL = form_data.url
  143. app.state.config.OPENAI_API_KEY = form_data.key
  144. return {
  145. "status": True,
  146. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  147. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  148. }
  149. class ImageSizeUpdateForm(BaseModel):
  150. size: str
  151. @app.get("/size")
  152. async def get_image_size(user=Depends(get_admin_user)):
  153. return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
  154. @app.post("/size/update")
  155. async def update_image_size(
  156. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  157. ):
  158. pattern = r"^\d+x\d+$" # Regular expression pattern
  159. if re.match(pattern, form_data.size):
  160. app.state.config.IMAGE_SIZE = form_data.size
  161. return {
  162. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  163. "status": True,
  164. }
  165. else:
  166. raise HTTPException(
  167. status_code=400,
  168. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  169. )
  170. class ImageStepsUpdateForm(BaseModel):
  171. steps: int
  172. @app.get("/steps")
  173. async def get_image_size(user=Depends(get_admin_user)):
  174. return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
  175. @app.post("/steps/update")
  176. async def update_image_size(
  177. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  178. ):
  179. if form_data.steps >= 0:
  180. app.state.config.IMAGE_STEPS = form_data.steps
  181. return {
  182. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  183. "status": True,
  184. }
  185. else:
  186. raise HTTPException(
  187. status_code=400,
  188. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  189. )
  190. @app.get("/models")
  191. def get_models(user=Depends(get_current_user)):
  192. try:
  193. if app.state.config.ENGINE == "openai":
  194. return [
  195. {"id": "dall-e-2", "name": "DALL·E 2"},
  196. {"id": "dall-e-3", "name": "DALL·E 3"},
  197. ]
  198. elif app.state.config.ENGINE == "comfyui":
  199. r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
  200. info = r.json()
  201. return list(
  202. map(
  203. lambda model: {"id": model, "name": model},
  204. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  205. )
  206. )
  207. else:
  208. r = requests.get(
  209. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
  210. )
  211. models = r.json()
  212. return list(
  213. map(
  214. lambda model: {"id": model["title"], "name": model["model_name"]},
  215. models,
  216. )
  217. )
  218. except Exception as e:
  219. app.state.config.ENABLED = False
  220. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  221. @app.get("/models/default")
  222. async def get_default_model(user=Depends(get_admin_user)):
  223. try:
  224. if app.state.config.ENGINE == "openai":
  225. return {
  226. "model": (
  227. app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
  228. )
  229. }
  230. elif app.state.config.ENGINE == "comfyui":
  231. return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
  232. else:
  233. r = requests.get(
  234. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
  235. )
  236. options = r.json()
  237. return {"model": options["sd_model_checkpoint"]}
  238. except Exception as e:
  239. app.state.config.ENABLED = False
  240. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  241. class UpdateModelForm(BaseModel):
  242. model: str
  243. def set_model_handler(model: str):
  244. if app.state.config.ENGINE in ["openai", "comfyui"]:
  245. app.state.config.MODEL = model
  246. return app.state.config.MODEL
  247. else:
  248. r = requests.get(
  249. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
  250. )
  251. options = r.json()
  252. if model != options["sd_model_checkpoint"]:
  253. options["sd_model_checkpoint"] = model
  254. r = requests.post(
  255. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  256. json=options,
  257. )
  258. return options
  259. @app.post("/models/default/update")
  260. def update_default_model(
  261. form_data: UpdateModelForm,
  262. user=Depends(get_current_user),
  263. ):
  264. return set_model_handler(form_data.model)
  265. class GenerateImageForm(BaseModel):
  266. model: Optional[str] = None
  267. prompt: str
  268. n: int = 1
  269. size: Optional[str] = None
  270. negative_prompt: Optional[str] = None
  271. def save_b64_image(b64_str):
  272. try:
  273. image_id = str(uuid.uuid4())
  274. if "," in b64_str:
  275. header, encoded = b64_str.split(",", 1)
  276. mime_type = header.split(";")[0]
  277. img_data = base64.b64decode(encoded)
  278. image_format = mimetypes.guess_extension(mime_type)
  279. image_filename = f"{image_id}{image_format}"
  280. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  281. with open(file_path, "wb") as f:
  282. f.write(img_data)
  283. return image_filename
  284. else:
  285. image_filename = f"{image_id}.png"
  286. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  287. img_data = base64.b64decode(b64_str)
  288. # Write the image data to a file
  289. with open(file_path, "wb") as f:
  290. f.write(img_data)
  291. return image_filename
  292. except Exception as e:
  293. log.exception(f"Error saving image: {e}")
  294. return None
  295. def save_url_image(url):
  296. image_id = str(uuid.uuid4())
  297. try:
  298. r = requests.get(url)
  299. r.raise_for_status()
  300. if r.headers["content-type"].split("/")[0] == "image":
  301. mime_type = r.headers["content-type"]
  302. image_format = mimetypes.guess_extension(mime_type)
  303. if not image_format:
  304. raise ValueError("Could not determine image type from MIME type")
  305. image_filename = f"{image_id}{image_format}"
  306. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  307. with open(file_path, "wb") as image_file:
  308. for chunk in r.iter_content(chunk_size=8192):
  309. image_file.write(chunk)
  310. return image_filename
  311. else:
  312. log.error(f"Url does not point to an image.")
  313. return None
  314. except Exception as e:
  315. log.exception(f"Error saving image: {e}")
  316. return None
  317. @app.post("/generations")
  318. def generate_image(
  319. form_data: GenerateImageForm,
  320. user=Depends(get_current_user),
  321. ):
  322. width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
  323. r = None
  324. try:
  325. if app.state.config.ENGINE == "openai":
  326. headers = {}
  327. headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
  328. headers["Content-Type"] = "application/json"
  329. data = {
  330. "model": (
  331. app.state.config.MODEL
  332. if app.state.config.MODEL != ""
  333. else "dall-e-2"
  334. ),
  335. "prompt": form_data.prompt,
  336. "n": form_data.n,
  337. "size": (
  338. form_data.size if form_data.size else app.state.config.IMAGE_SIZE
  339. ),
  340. "response_format": "b64_json",
  341. }
  342. r = requests.post(
  343. url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
  344. json=data,
  345. headers=headers,
  346. )
  347. r.raise_for_status()
  348. res = r.json()
  349. images = []
  350. for image in res["data"]:
  351. image_filename = save_b64_image(image["b64_json"])
  352. images.append({"url": f"/cache/image/generations/{image_filename}"})
  353. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  354. with open(file_body_path, "w") as f:
  355. json.dump(data, f)
  356. return images
  357. elif app.state.config.ENGINE == "comfyui":
  358. data = {
  359. "prompt": form_data.prompt,
  360. "width": width,
  361. "height": height,
  362. "n": form_data.n,
  363. }
  364. if app.state.config.IMAGE_STEPS is not None:
  365. data["steps"] = app.state.config.IMAGE_STEPS
  366. if form_data.negative_prompt is not None:
  367. data["negative_prompt"] = form_data.negative_prompt
  368. if app.state.config.COMFYUI_CFG_SCALE:
  369. data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
  370. if app.state.config.COMFYUI_SAMPLER is not None:
  371. data["sampler"] = app.state.config.COMFYUI_SAMPLER
  372. if app.state.config.COMFYUI_SCHEDULER is not None:
  373. data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
  374. if app.state.config.COMFYUI_SD3 is not None:
  375. data["sd3"] = app.state.config.COMFYUI_SD3
  376. data = ImageGenerationPayload(**data)
  377. res = comfyui_generate_image(
  378. app.state.config.MODEL,
  379. data,
  380. user.id,
  381. app.state.config.COMFYUI_BASE_URL,
  382. )
  383. log.debug(f"res: {res}")
  384. images = []
  385. for image in res["data"]:
  386. image_filename = save_url_image(image["url"])
  387. images.append({"url": f"/cache/image/generations/{image_filename}"})
  388. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  389. with open(file_body_path, "w") as f:
  390. json.dump(data.model_dump(exclude_none=True), f)
  391. log.debug(f"images: {images}")
  392. return images
  393. else:
  394. if form_data.model:
  395. set_model_handler(form_data.model)
  396. data = {
  397. "prompt": form_data.prompt,
  398. "batch_size": form_data.n,
  399. "width": width,
  400. "height": height,
  401. }
  402. if app.state.config.IMAGE_STEPS is not None:
  403. data["steps"] = app.state.config.IMAGE_STEPS
  404. if form_data.negative_prompt is not None:
  405. data["negative_prompt"] = form_data.negative_prompt
  406. r = requests.post(
  407. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  408. json=data,
  409. )
  410. res = r.json()
  411. log.debug(f"res: {res}")
  412. images = []
  413. for image in res["images"]:
  414. image_filename = save_b64_image(image)
  415. images.append({"url": f"/cache/image/generations/{image_filename}"})
  416. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  417. with open(file_body_path, "w") as f:
  418. json.dump({**data, "info": res["info"]}, f)
  419. return images
  420. except Exception as e:
  421. error = e
  422. if r != None:
  423. data = r.json()
  424. if "error" in data:
  425. error = data["error"]["message"]
  426. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))