main.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
  21. from utils.misc import calculate_sha256
  22. from typing import Optional
  23. from pydantic import BaseModel
  24. from pathlib import Path
  25. import mimetypes
  26. import uuid
  27. import base64
  28. import json
  29. import logging
  30. from config import (
  31. SRC_LOG_LEVELS,
  32. CACHE_DIR,
  33. IMAGE_GENERATION_ENGINE,
  34. ENABLE_IMAGE_GENERATION,
  35. AUTOMATIC1111_BASE_URL,
  36. COMFYUI_BASE_URL,
  37. IMAGES_OPENAI_API_BASE_URL,
  38. IMAGES_OPENAI_API_KEY,
  39. IMAGE_GENERATION_MODEL,
  40. IMAGE_SIZE,
  41. IMAGE_STEPS,
  42. )
  43. log = logging.getLogger(__name__)
  44. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  45. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  46. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  47. app = FastAPI()
  48. app.add_middleware(
  49. CORSMiddleware,
  50. allow_origins=["*"],
  51. allow_credentials=True,
  52. allow_methods=["*"],
  53. allow_headers=["*"],
  54. )
  55. app.state.ENGINE = IMAGE_GENERATION_ENGINE
  56. app.state.ENABLED = ENABLE_IMAGE_GENERATION
  57. app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  58. app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
  59. app.state.MODEL = IMAGE_GENERATION_MODEL
  60. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  61. app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  62. app.state.IMAGE_SIZE = IMAGE_SIZE
  63. app.state.IMAGE_STEPS = IMAGE_STEPS
  64. @app.get("/config")
  65. async def get_config(request: Request, user=Depends(get_admin_user)):
  66. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  67. class ConfigUpdateForm(BaseModel):
  68. engine: str
  69. enabled: bool
  70. @app.post("/config/update")
  71. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  72. app.state.ENGINE = form_data.engine
  73. app.state.ENABLED = form_data.enabled
  74. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  75. class EngineUrlUpdateForm(BaseModel):
  76. AUTOMATIC1111_BASE_URL: Optional[str] = None
  77. COMFYUI_BASE_URL: Optional[str] = None
  78. @app.get("/url")
  79. async def get_engine_url(user=Depends(get_admin_user)):
  80. return {
  81. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  82. "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
  83. }
  84. @app.post("/url/update")
  85. async def update_engine_url(
  86. form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  87. ):
  88. if form_data.AUTOMATIC1111_BASE_URL == None:
  89. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  90. else:
  91. url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
  92. try:
  93. r = requests.head(url)
  94. app.state.AUTOMATIC1111_BASE_URL = url
  95. except Exception as e:
  96. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  97. if form_data.COMFYUI_BASE_URL == None:
  98. app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  99. else:
  100. url = form_data.COMFYUI_BASE_URL.strip("/")
  101. try:
  102. r = requests.head(url)
  103. app.state.COMFYUI_BASE_URL = url
  104. except Exception as e:
  105. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  106. return {
  107. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  108. "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
  109. "status": True,
  110. }
  111. class OpenAIConfigUpdateForm(BaseModel):
  112. url: str
  113. key: str
  114. @app.get("/openai/config")
  115. async def get_openai_config(user=Depends(get_admin_user)):
  116. return {
  117. "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
  118. "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
  119. }
  120. @app.post("/openai/config/update")
  121. async def update_openai_config(
  122. form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
  123. ):
  124. if form_data.key == "":
  125. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  126. app.state.OPENAI_API_BASE_URL = form_data.url
  127. app.state.OPENAI_API_KEY = form_data.key
  128. return {
  129. "status": True,
  130. "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
  131. "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
  132. }
  133. class ImageSizeUpdateForm(BaseModel):
  134. size: str
  135. @app.get("/size")
  136. async def get_image_size(user=Depends(get_admin_user)):
  137. return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
  138. @app.post("/size/update")
  139. async def update_image_size(
  140. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  141. ):
  142. pattern = r"^\d+x\d+$" # Regular expression pattern
  143. if re.match(pattern, form_data.size):
  144. app.state.IMAGE_SIZE = form_data.size
  145. return {
  146. "IMAGE_SIZE": app.state.IMAGE_SIZE,
  147. "status": True,
  148. }
  149. else:
  150. raise HTTPException(
  151. status_code=400,
  152. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  153. )
  154. class ImageStepsUpdateForm(BaseModel):
  155. steps: int
  156. @app.get("/steps")
  157. async def get_image_size(user=Depends(get_admin_user)):
  158. return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
  159. @app.post("/steps/update")
  160. async def update_image_size(
  161. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  162. ):
  163. if form_data.steps >= 0:
  164. app.state.IMAGE_STEPS = form_data.steps
  165. return {
  166. "IMAGE_STEPS": app.state.IMAGE_STEPS,
  167. "status": True,
  168. }
  169. else:
  170. raise HTTPException(
  171. status_code=400,
  172. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  173. )
  174. @app.get("/models")
  175. def get_models(user=Depends(get_current_user)):
  176. try:
  177. if app.state.ENGINE == "openai":
  178. return [
  179. {"id": "dall-e-2", "name": "DALL·E 2"},
  180. {"id": "dall-e-3", "name": "DALL·E 3"},
  181. ]
  182. elif app.state.ENGINE == "comfyui":
  183. r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
  184. info = r.json()
  185. return list(
  186. map(
  187. lambda model: {"id": model, "name": model},
  188. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  189. )
  190. )
  191. else:
  192. r = requests.get(
  193. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
  194. )
  195. models = r.json()
  196. return list(
  197. map(
  198. lambda model: {"id": model["title"], "name": model["model_name"]},
  199. models,
  200. )
  201. )
  202. except Exception as e:
  203. app.state.ENABLED = False
  204. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  205. @app.get("/models/default")
  206. async def get_default_model(user=Depends(get_admin_user)):
  207. try:
  208. if app.state.ENGINE == "openai":
  209. return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
  210. elif app.state.ENGINE == "comfyui":
  211. return {"model": app.state.MODEL if app.state.MODEL else ""}
  212. else:
  213. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  214. options = r.json()
  215. return {"model": options["sd_model_checkpoint"]}
  216. except Exception as e:
  217. app.state.ENABLED = False
  218. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  219. class UpdateModelForm(BaseModel):
  220. model: str
  221. def set_model_handler(model: str):
  222. if app.state.ENGINE == "openai":
  223. app.state.MODEL = model
  224. return app.state.MODEL
  225. if app.state.ENGINE == "comfyui":
  226. app.state.MODEL = model
  227. return app.state.MODEL
  228. else:
  229. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  230. options = r.json()
  231. if model != options["sd_model_checkpoint"]:
  232. options["sd_model_checkpoint"] = model
  233. r = requests.post(
  234. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
  235. )
  236. return options
  237. @app.post("/models/default/update")
  238. def update_default_model(
  239. form_data: UpdateModelForm,
  240. user=Depends(get_current_user),
  241. ):
  242. return set_model_handler(form_data.model)
  243. class GenerateImageForm(BaseModel):
  244. model: Optional[str] = None
  245. prompt: str
  246. n: int = 1
  247. size: Optional[str] = None
  248. negative_prompt: Optional[str] = None
  249. def save_b64_image(b64_str):
  250. try:
  251. header, encoded = b64_str.split(",", 1)
  252. mime_type = header.split(";")[0]
  253. img_data = base64.b64decode(encoded)
  254. image_id = str(uuid.uuid4())
  255. image_format = mimetypes.guess_extension(mime_type)
  256. image_filename = f"{image_id}{image_format}"
  257. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  258. with open(file_path, "wb") as f:
  259. f.write(img_data)
  260. return image_filename
  261. except Exception as e:
  262. log.exception(f"Error saving image: {e}")
  263. return None
  264. def save_url_image(url):
  265. image_id = str(uuid.uuid4())
  266. try:
  267. r = requests.get(url)
  268. r.raise_for_status()
  269. if r.headers["content-type"].split("/")[0] == "image":
  270. mime_type = r.headers["content-type"]
  271. image_format = mimetypes.guess_extension(mime_type)
  272. if not image_format:
  273. raise ValueError("Could not determine image type from MIME type")
  274. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}")
  275. with open(file_path, "wb") as image_file:
  276. for chunk in r.iter_content(chunk_size=8192):
  277. image_file.write(chunk)
  278. return image_id, image_format
  279. else:
  280. log.error(f"Url does not point to an image.")
  281. return None, None
  282. except Exception as e:
  283. log.exception(f"Error saving image: {e}")
  284. return None, None
  285. @app.post("/generations")
  286. def generate_image(
  287. form_data: GenerateImageForm,
  288. user=Depends(get_current_user),
  289. ):
  290. width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
  291. r = None
  292. try:
  293. if app.state.ENGINE == "openai":
  294. headers = {}
  295. headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
  296. headers["Content-Type"] = "application/json"
  297. data = {
  298. "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
  299. "prompt": form_data.prompt,
  300. "n": form_data.n,
  301. "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
  302. "response_format": "b64_json",
  303. }
  304. r = requests.post(
  305. url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
  306. json=data,
  307. headers=headers,
  308. )
  309. r.raise_for_status()
  310. res = r.json()
  311. images = []
  312. for image in res["data"]:
  313. image_filename = save_b64_image(image["b64_json"])
  314. images.append({"url": f"/cache/image/generations/{image_filename}"})
  315. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  316. with open(file_body_path, "w") as f:
  317. json.dump(data, f)
  318. return images
  319. elif app.state.ENGINE == "comfyui":
  320. data = {
  321. "prompt": form_data.prompt,
  322. "width": width,
  323. "height": height,
  324. "n": form_data.n,
  325. }
  326. if app.state.IMAGE_STEPS != None:
  327. data["steps"] = app.state.IMAGE_STEPS
  328. if form_data.negative_prompt != None:
  329. data["negative_prompt"] = form_data.negative_prompt
  330. data = ImageGenerationPayload(**data)
  331. res = comfyui_generate_image(
  332. app.state.MODEL,
  333. data,
  334. user.id,
  335. app.state.COMFYUI_BASE_URL,
  336. )
  337. log.debug(f"res: {res}")
  338. images = []
  339. for image in res["data"]:
  340. image_id, image_format = save_url_image(image["url"])
  341. images.append(
  342. {"url": f"/cache/image/generations/{image_id}{image_format}"}
  343. )
  344. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  345. with open(file_body_path, "w") as f:
  346. json.dump(data.model_dump(exclude_none=True), f)
  347. log.debug(f"images: {images}")
  348. return images
  349. else:
  350. if form_data.model:
  351. set_model_handler(form_data.model)
  352. data = {
  353. "prompt": form_data.prompt,
  354. "batch_size": form_data.n,
  355. "width": width,
  356. "height": height,
  357. }
  358. if app.state.IMAGE_STEPS != None:
  359. data["steps"] = app.state.IMAGE_STEPS
  360. if form_data.negative_prompt != None:
  361. data["negative_prompt"] = form_data.negative_prompt
  362. r = requests.post(
  363. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  364. json=data,
  365. )
  366. res = r.json()
  367. log.debug(f"res: {res}")
  368. images = []
  369. for image in res["images"]:
  370. image_filename = save_b64_image(image)
  371. images.append({"url": f"/cache/image/generations/{image_filename}"})
  372. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  373. with open(file_body_path, "w") as f:
  374. json.dump({**data, "info": res["info"]}, f)
  375. return images
  376. except Exception as e:
  377. error = e
  378. if r != None:
  379. data = r.json()
  380. if "error" in data:
  381. error = data["error"]["message"]
  382. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))