main.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
  21. from utils.misc import calculate_sha256
  22. from typing import Optional
  23. from pydantic import BaseModel
  24. from pathlib import Path
  25. import uuid
  26. import base64
  27. import json
  28. import logging
  29. from config import (
  30. SRC_LOG_LEVELS,
  31. CACHE_DIR,
  32. ENABLE_IMAGE_GENERATION,
  33. AUTOMATIC1111_BASE_URL,
  34. COMFYUI_BASE_URL,
  35. )
  # Module logger; its level comes from the "IMAGES" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["IMAGES"])

  # Generated images (<uuid>.png) and their JSON sidecars (<uuid>.json) are
  # written to this directory; it is created at import time if missing.
  IMAGE_CACHE_DIR = Path(CACHE_DIR) / "image" / "generations"
  IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

  app = FastAPI()

  # NOTE(review): wildcard origins combined with allow_credentials=True may let
  # arbitrary sites issue credentialed requests; confirm this is intended.
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
      allow_methods=["*"],
      allow_headers=["*"],
      allow_credentials=True,
  )

  # Runtime-mutable settings; the endpoints in this module reassign them.
  app.state.ENGINE = ""  # "openai", "comfyui"; any other value selects AUTOMATIC1111
  app.state.ENABLED = ENABLE_IMAGE_GENERATION
  app.state.OPENAI_API_KEY = ""
  app.state.MODEL = ""
  app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  app.state.IMAGE_SIZE = "512x512"  # "<width>x<height>"
  app.state.IMAGE_STEPS = 50
  @app.get("/config")
  async def get_config(request: Request, user=Depends(get_admin_user)):
      """Report the selected image engine and the enabled flag (admin only)."""
      state = app.state
      return {"engine": state.ENGINE, "enabled": state.ENABLED}
  class ConfigUpdateForm(BaseModel):
      # Engine selector: "openai" or "comfyui"; any other value falls through to
      # the AUTOMATIC1111 code paths in this module.
      engine: str
      # Whether image generation is switched on.
      enabled: bool


  @app.post("/config/update")
  async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
      """Store a new engine selection and enabled flag, then echo them back (admin only)."""
      state = app.state
      state.ENGINE, state.ENABLED = form_data.engine, form_data.enabled
      return {"engine": state.ENGINE, "enabled": state.ENABLED}
  class EngineUrlUpdateForm(BaseModel):
      # Both URLs are optional; the /url/update handler treats a missing (None)
      # value as "reset to the configured default".
      AUTOMATIC1111_BASE_URL: Optional[str] = None
      COMFYUI_BASE_URL: Optional[str] = None


  @app.get("/url")
  async def get_engine_url(user=Depends(get_admin_user)):
      """Return the base URLs currently used for AUTOMATIC1111 and ComfyUI (admin only)."""
      state = app.state
      return dict(
          AUTOMATIC1111_BASE_URL=state.AUTOMATIC1111_BASE_URL,
          COMFYUI_BASE_URL=state.COMFYUI_BASE_URL,
      )
  def _probe_base_url(raw_url, timeout=10):
      """Normalise a user-supplied engine base URL and check that it answers.

      Leading and trailing "/" characters are stripped, then a HEAD request is
      sent. Any failure to complete that request (bad scheme, DNS failure,
      refused connection, timeout, ...) is reported as HTTPException(400). The
      HTTP status of the response is not inspected, only that one arrived.

      Args:
          raw_url: URL as submitted by the admin.
          timeout: seconds to wait for the HEAD request; without a timeout an
              unresponsive host could stall the request indefinitely.

      Returns:
          The normalised URL.
      """
      url = raw_url.strip("/")
      try:
          requests.head(url, timeout=timeout)
      except Exception as e:
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
      return url


  @app.post("/url/update")
  async def update_engine_url(
      form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  ):
      """Update the AUTOMATIC1111 and ComfyUI base URLs (admin only).

      A field left as None resets that engine's URL to the value imported from
      config. Supplied URLs are normalised and probed, and both are validated
      before either is stored, so a failure on one leaves both unchanged.

      Raises:
          HTTPException(400): if a supplied URL cannot be reached.
      """
      # NOTE(review): requests is blocking and this handler is async, so each
      # probe can stall the event loop for up to its timeout.
      new_automatic1111_url = (
          AUTOMATIC1111_BASE_URL
          if form_data.AUTOMATIC1111_BASE_URL is None
          else _probe_base_url(form_data.AUTOMATIC1111_BASE_URL)
      )
      new_comfyui_url = (
          COMFYUI_BASE_URL
          if form_data.COMFYUI_BASE_URL is None
          else _probe_base_url(form_data.COMFYUI_BASE_URL)
      )

      # Commit only after every supplied URL passed validation.
      app.state.AUTOMATIC1111_BASE_URL = new_automatic1111_url
      app.state.COMFYUI_BASE_URL = new_comfyui_url
      return {
          "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
          "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
          "status": True,
      }
  class OpenAIKeyUpdateForm(BaseModel):
      # New OpenAI API key; the update endpoint rejects an empty string.
      key: str


  @app.get("/key")
  async def get_openai_key(user=Depends(get_admin_user)):
      """Return the stored OpenAI API key (admin only)."""
      return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}


  @app.post("/key/update")
  async def update_openai_key(
      form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
  ):
      """Replace the stored OpenAI API key and echo it back (admin only).

      Raises:
          HTTPException(400): if the submitted key is empty.
      """
      new_key = form_data.key
      if not new_key:
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
      app.state.OPENAI_API_KEY = new_key
      return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "status": True}
  class ImageSizeUpdateForm(BaseModel):
      # Image size as "<width>x<height>", e.g. "512x512".
      size: str


  @app.get("/size")
  async def get_image_size(user=Depends(get_admin_user)):
      """Return the configured default image size (admin only)."""
      return {"IMAGE_SIZE": app.state.IMAGE_SIZE}


  @app.post("/size/update")
  async def update_image_size(
      form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  ):
      """Validate and store a new default image size (admin only).

      The value must be exactly "<width>x<height>" using ASCII digits, with both
      dimensions greater than zero. It is stored verbatim: generate_image splits
      it into ints for ComfyUI/AUTOMATIC1111 and sends the string itself to
      OpenAI when a request carries no explicit size.

      Raises:
          HTTPException(400): if the value is not in that format.
      """
      # fullmatch rather than match + "$": "$" also matches just before a
      # trailing newline, which let "512x512\n" through. [0-9] rather than \d
      # keeps out non-ASCII digits, since the raw string is forwarded as-is.
      dims = re.fullmatch(r"([0-9]+)x([0-9]+)", form_data.size)
      if dims is None or any(int(d) == 0 for d in dims.groups()):
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
          )
      app.state.IMAGE_SIZE = form_data.size
      return {
          "IMAGE_SIZE": app.state.IMAGE_SIZE,
          "status": True,
      }
  class ImageStepsUpdateForm(BaseModel):
      # Number of sampling steps sent to ComfyUI / AUTOMATIC1111.
      steps: int


  # Distinct handler names: reusing get_image_size / update_image_size here would
  # rebind (shadow) the module-level names of the /size handlers.
  @app.get("/steps")
  async def get_image_steps(user=Depends(get_admin_user)):
      """Return the configured number of sampling steps (admin only)."""
      return {"IMAGE_STEPS": app.state.IMAGE_STEPS}


  @app.post("/steps/update")
  async def update_image_steps(
      form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  ):
      """Store a new sampling step count and echo it back (admin only).

      Raises:
          HTTPException(400): if steps is negative.
      """
      if form_data.steps < 0:
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
          )
      app.state.IMAGE_STEPS = form_data.steps
      return {
          "IMAGE_STEPS": app.state.IMAGE_STEPS,
          "status": True,
      }
  @app.get("/models")
  def get_models(user=Depends(get_current_user)):
      """List the models available for the active engine.

      Returns a list of ``{"id": ..., "name": ...}`` dicts:
        * openai: the fixed DALL·E 2 / DALL·E 3 entries;
        * comfyui: checkpoint names from ComfyUI's /object_info;
        * anything else (AUTOMATIC1111): /sdapi/v1/sd-models, using "title" as
          id and "model_name" as name.

      Any failure disables image generation (``app.state.ENABLED = False``) and
      is reported as HTTPException(400).
      """
      try:
          if app.state.ENGINE == "openai":
              return [
                  {"id": "dall-e-2", "name": "DALL·E 2"},
                  {"id": "dall-e-3", "name": "DALL·E 3"},
              ]

          if app.state.ENGINE == "comfyui":
              r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
              # Fail on HTTP errors up front instead of on a confusing KeyError /
              # JSON decode error from the error body.
              r.raise_for_status()
              info = r.json()
              # Presumably the first element is the list of checkpoint names the
              # CheckpointLoaderSimple node accepts -- confirm against ComfyUI.
              checkpoints = info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0]
              return [{"id": name, "name": name} for name in checkpoints]

          r = requests.get(
              url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
          )
          r.raise_for_status()
          return [
              {"id": model["title"], "name": model["model_name"]}
              for model in r.json()
          ]
      except Exception as e:
          app.state.ENABLED = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  @app.get("/models/default")
  async def get_default_model(user=Depends(get_admin_user)):
      """Return ``{"model": ...}`` for the active engine (admin only).

      * openai: the stored model, falling back to "dall-e-2";
      * comfyui: the stored model, or "" when none is stored;
      * anything else (AUTOMATIC1111): the ``sd_model_checkpoint`` option
        reported by the WebUI's /sdapi/v1/options endpoint.

      Any failure disables image generation (``app.state.ENABLED = False``) and
      is reported as HTTPException(400).
      """
      try:
          if app.state.ENGINE == "openai":
              return {"model": app.state.MODEL or "dall-e-2"}
          if app.state.ENGINE == "comfyui":
              return {"model": app.state.MODEL or ""}

          # NOTE(review): blocking request inside an async handler.
          r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
          # Surface HTTP errors directly rather than failing later on an
          # unexpected (possibly non-JSON) error body.
          r.raise_for_status()
          return {"model": r.json()["sd_model_checkpoint"]}
      except Exception as e:
          app.state.ENABLED = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  class UpdateModelForm(BaseModel):
      # Model identifier to make the default for the active engine.
      model: str


  def set_model_handler(model: str):
      """Make *model* the default for the active engine.

      For "openai" and "comfyui" the choice is only stored in ``app.state.MODEL``
      and that value is returned. For any other engine (AUTOMATIC1111) the
      WebUI's options are fetched; if ``sd_model_checkpoint`` differs from
      *model* the updated options are posted back. The (possibly updated)
      options dict is returned.

      Raises:
          requests.HTTPError: if AUTOMATIC1111 rejects the GET or the POST.
      """
      if app.state.ENGINE in ("openai", "comfyui"):
          app.state.MODEL = model
          return app.state.MODEL

      options_url = f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
      r = requests.get(url=options_url)
      r.raise_for_status()
      options = r.json()
      if model != options["sd_model_checkpoint"]:
          options["sd_model_checkpoint"] = model
          # Check the POST: otherwise a rejected switch would still report the
          # new checkpoint back to the caller.
          requests.post(url=options_url, json=options).raise_for_status()
      return options
  @app.post("/models/default/update")
  def update_default_model(
      form_data: UpdateModelForm,
      user=Depends(get_current_user),
  ):
      """Set the default model for the active engine; returns set_model_handler's result.

      NOTE(review): unlike the other configuration endpoints this depends on
      get_current_user rather than get_admin_user, although it changes global
      state. Confirm that is intended.
      """
      requested_model = form_data.model
      return set_model_handler(requested_model)
  class GenerateImageForm(BaseModel):
      # Requested model. Only generate_image's AUTOMATIC1111 path reads it (to
      # switch checkpoint); the openai and comfyui paths use app.state.MODEL.
      model: Optional[str] = None
      # Text prompt (required).
      prompt: str
      # Number of images to request.
      n: int = 1
      # "<width>x<height>"; only generate_image's openai path reads it.
      size: Optional[str] = None
      # Forwarded to ComfyUI and AUTOMATIC1111 when set.
      negative_prompt: Optional[str] = None
  def save_b64_image(b64_str):
      """Decode a base64 image and write it to IMAGE_CACHE_DIR as ``<uuid>.png``.

      Accepts raw base64, or a data URI (``data:<mime>;base64,<payload>``), in
      which case only the payload after the first comma is decoded. Without this
      the header's letters would be decoded into the image bytes.

      Returns:
          The generated image id (file stem) on success, or None if decoding or
          writing failed; the failure is logged with its traceback.
      """
      image_id = str(uuid.uuid4())
      file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
      try:
          # ":" is not in the base64 alphabet, so a "data:" prefix can only be
          # a data-URI header, never part of a raw payload.
          if isinstance(b64_str, str) and b64_str.startswith("data:"):
              b64_str = b64_str.split(",", 1)[-1]
          img_data = base64.b64decode(b64_str)
          file_path.write_bytes(img_data)
          return image_id
      except Exception as e:
          log.exception(f"Error saving image: {e}")
          return None
  def save_url_image(url):
      """Download *url* and store the response body as ``<uuid>.png`` in IMAGE_CACHE_DIR.

      Returns the image id (file stem) on success. On any failure, including a
      non-2xx response, the exception is logged and None is returned.
      """
      # NOTE(review): no timeout on the download; a stalled server blocks the caller.
      image_id = str(uuid.uuid4())
      destination = IMAGE_CACHE_DIR / f"{image_id}.png"
      try:
          response = requests.get(url)
          response.raise_for_status()
          destination.write_bytes(response.content)
      except Exception as e:
          log.exception(f"Error saving image: {e}")
          return None
      return image_id
  def _cache_image_entry(image_id, metadata):
      """Write *metadata* as the JSON sidecar of a cached image and return its URL entry.

      Raises ValueError when *image_id* is None (the image itself could not be
      saved), so no link to a missing file is handed out.
      """
      if image_id is None:
          raise ValueError("Failed to save generated image")
      with open(IMAGE_CACHE_DIR.joinpath(f"{image_id}.json"), "w") as f:
          json.dump(metadata, f)
      return {"url": f"/cache/image/generations/{image_id}.png"}


  def _error_message_from_response(r, default):
      """Return ``error.message`` from a JSON body of *r*, else *default*.

      Never raises: a missing response, a non-JSON body, or any other body shape
      (e.g. AUTOMATIC1111's string-valued "error") falls back to *default*.
      """
      if r is None:
          return default
      try:
          body = r.json()
      except ValueError:
          return default
      error = body.get("error") if isinstance(body, dict) else None
      if isinstance(error, dict) and "message" in error:
          return error["message"]
      return default


  @app.post("/generations")
  def generate_image(
      form_data: GenerateImageForm,
      user=Depends(get_current_user),
  ):
      """Generate images with the active engine and cache them locally.

      Each produced image is stored as ``<id>.png`` in IMAGE_CACHE_DIR next to a
      ``<id>.json`` sidecar holding the request parameters (plus AUTOMATIC1111's
      ``info``). The return value is a list of
      ``{"url": "/cache/image/generations/<id>.png"}`` entries.

      Engines:
          openai: POSTs to the OpenAI images API with the stored key and model
              ("dall-e-2" if none is stored); ``form_data.size`` overrides the
              configured size.
          comfyui: delegates to comfyui_generate_image with the configured
              size and steps.
          anything else (AUTOMATIC1111): switches checkpoint to
              ``form_data.model`` when given, then calls /sdapi/v1/txt2img.

      Raises:
          HTTPException(400): on any failure. If the engine answered with a JSON
              body of the form ``{"error": {"message": ...}}`` that message is
              the detail; otherwise the original exception is.
      """
      # IMAGE_SIZE is validated as "<int>x<int>" by the /size/update endpoint.
      width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
      # Only the HTTP-based engine paths assign r; the error handler inspects it.
      r = None
      try:
          if app.state.ENGINE == "openai":
              headers = {
                  "Authorization": f"Bearer {app.state.OPENAI_API_KEY}",
                  "Content-Type": "application/json",
              }
              data = {
                  "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
                  "prompt": form_data.prompt,
                  "n": form_data.n,
                  "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
                  "response_format": "b64_json",
              }
              r = requests.post(
                  url="https://api.openai.com/v1/images/generations",
                  json=data,
                  headers=headers,
              )
              r.raise_for_status()
              res = r.json()
              return [
                  _cache_image_entry(save_b64_image(image["b64_json"]), data)
                  for image in res["data"]
              ]

          if app.state.ENGINE == "comfyui":
              data = {
                  "prompt": form_data.prompt,
                  "width": width,
                  "height": height,
                  "n": form_data.n,
              }
              if app.state.IMAGE_STEPS is not None:
                  data["steps"] = app.state.IMAGE_STEPS
              if form_data.negative_prompt is not None:
                  data["negative_prompt"] = form_data.negative_prompt
              payload = ImageGenerationPayload(**data)
              res = comfyui_generate_image(
                  app.state.MODEL,
                  payload,
                  user.id,
                  app.state.COMFYUI_BASE_URL,
              )
              log.debug(f"res: {res}")
              metadata = payload.model_dump(exclude_none=True)
              images = [
                  _cache_image_entry(save_url_image(image["url"]), metadata)
                  for image in res["data"]
              ]
              log.debug(f"images: {images}")
              return images

          # AUTOMATIC1111
          if form_data.model:
              set_model_handler(form_data.model)
          data = {
              "prompt": form_data.prompt,
              "batch_size": form_data.n,
              "width": width,
              "height": height,
          }
          if app.state.IMAGE_STEPS is not None:
              data["steps"] = app.state.IMAGE_STEPS
          if form_data.negative_prompt is not None:
              data["negative_prompt"] = form_data.negative_prompt
          r = requests.post(
              url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
              json=data,
          )
          # Without this an error body failed later as KeyError('images').
          r.raise_for_status()
          res = r.json()
          log.debug(f"res: {res}")
          return [
              _cache_image_entry(save_b64_image(image), {**data, "info": res["info"]})
              for image in res["images"]
          ]
      except Exception as e:
          # The extraction helper never raises, so a non-JSON or unexpectedly
          # shaped error body can no longer turn this 400 into a 500.
          error = _error_message_from_response(r, e)
          raise HTTPException(
              status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)
          ) from e