main.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
  21. from utils.misc import calculate_sha256
  22. from typing import Optional
  23. from pydantic import BaseModel
  24. from pathlib import Path
  25. import uuid
  26. import base64
  27. import json
  28. import logging
  29. from config import (
  30. SRC_LOG_LEVELS,
  31. CACHE_DIR,
  32. ENABLE_IMAGE_GENERATION,
  33. AUTOMATIC1111_BASE_URL,
  34. COMFYUI_BASE_URL,
  35. IMAGE_OPENAI_API_BASE_URL,
  36. IMAGE_OPENAI_API_KEY,
  37. )
  38. log = logging.getLogger(__name__)
  39. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  40. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  41. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  42. app = FastAPI()
  43. app.add_middleware(
  44. CORSMiddleware,
  45. allow_origins=["*"],
  46. allow_credentials=True,
  47. allow_methods=["*"],
  48. allow_headers=["*"],
  49. )
  50. app.state.ENGINE = ""
  51. app.state.ENABLED = ENABLE_IMAGE_GENERATION
  52. app.state.OPENAI_API_BASE_URL = IMAGE_OPENAI_API_BASE_URL
  53. app.state.OPENAI_API_KEY = IMAGE_OPENAI_API_KEY
  54. app.state.MODEL = ""
  55. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  56. app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  57. app.state.IMAGE_SIZE = "512x512"
  58. app.state.IMAGE_STEPS = 50
  59. @app.get("/config")
  60. async def get_config(request: Request, user=Depends(get_admin_user)):
  61. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  62. class ConfigUpdateForm(BaseModel):
  63. engine: str
  64. enabled: bool
  65. @app.post("/config/update")
  66. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  67. app.state.ENGINE = form_data.engine
  68. app.state.ENABLED = form_data.enabled
  69. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  70. class EngineUrlUpdateForm(BaseModel):
  71. AUTOMATIC1111_BASE_URL: Optional[str] = None
  72. COMFYUI_BASE_URL: Optional[str] = None
  73. @app.get("/url")
  74. async def get_engine_url(user=Depends(get_admin_user)):
  75. return {
  76. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  77. "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
  78. }
  79. @app.post("/url/update")
  80. async def update_engine_url(
  81. form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  82. ):
  83. if form_data.AUTOMATIC1111_BASE_URL == None:
  84. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  85. else:
  86. url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
  87. try:
  88. r = requests.head(url)
  89. app.state.AUTOMATIC1111_BASE_URL = url
  90. except Exception as e:
  91. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  92. if form_data.COMFYUI_BASE_URL == None:
  93. app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  94. else:
  95. url = form_data.COMFYUI_BASE_URL.strip("/")
  96. try:
  97. r = requests.head(url)
  98. app.state.COMFYUI_BASE_URL = url
  99. except Exception as e:
  100. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  101. return {
  102. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  103. "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
  104. "status": True,
  105. }
  106. class OpenAIKeyUpdateForm(BaseModel):
  107. key: str
  108. @app.get("/key")
  109. async def get_openai_key(user=Depends(get_admin_user)):
  110. return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
  111. @app.post("/key/update")
  112. async def update_openai_key(
  113. form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
  114. ):
  115. if form_data.key == "":
  116. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  117. app.state.OPENAI_API_KEY = form_data.key
  118. return {
  119. "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
  120. "status": True,
  121. }
  122. class ImageSizeUpdateForm(BaseModel):
  123. size: str
  124. @app.get("/size")
  125. async def get_image_size(user=Depends(get_admin_user)):
  126. return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
  127. @app.post("/size/update")
  128. async def update_image_size(
  129. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  130. ):
  131. pattern = r"^\d+x\d+$" # Regular expression pattern
  132. if re.match(pattern, form_data.size):
  133. app.state.IMAGE_SIZE = form_data.size
  134. return {
  135. "IMAGE_SIZE": app.state.IMAGE_SIZE,
  136. "status": True,
  137. }
  138. else:
  139. raise HTTPException(
  140. status_code=400,
  141. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  142. )
  143. class ImageStepsUpdateForm(BaseModel):
  144. steps: int
  145. @app.get("/steps")
  146. async def get_image_size(user=Depends(get_admin_user)):
  147. return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
  148. @app.post("/steps/update")
  149. async def update_image_size(
  150. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  151. ):
  152. if form_data.steps >= 0:
  153. app.state.IMAGE_STEPS = form_data.steps
  154. return {
  155. "IMAGE_STEPS": app.state.IMAGE_STEPS,
  156. "status": True,
  157. }
  158. else:
  159. raise HTTPException(
  160. status_code=400,
  161. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  162. )
  163. @app.get("/models")
  164. def get_models(user=Depends(get_current_user)):
  165. try:
  166. if app.state.ENGINE == "openai":
  167. return [
  168. {"id": "dall-e-2", "name": "DALL·E 2"},
  169. {"id": "dall-e-3", "name": "DALL·E 3"},
  170. ]
  171. elif app.state.ENGINE == "comfyui":
  172. r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
  173. info = r.json()
  174. return list(
  175. map(
  176. lambda model: {"id": model, "name": model},
  177. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  178. )
  179. )
  180. else:
  181. r = requests.get(
  182. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
  183. )
  184. models = r.json()
  185. return list(
  186. map(
  187. lambda model: {"id": model["title"], "name": model["model_name"]},
  188. models,
  189. )
  190. )
  191. except Exception as e:
  192. app.state.ENABLED = False
  193. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  194. @app.get("/models/default")
  195. async def get_default_model(user=Depends(get_admin_user)):
  196. try:
  197. if app.state.ENGINE == "openai":
  198. return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
  199. elif app.state.ENGINE == "comfyui":
  200. return {"model": app.state.MODEL if app.state.MODEL else ""}
  201. else:
  202. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  203. options = r.json()
  204. return {"model": options["sd_model_checkpoint"]}
  205. except Exception as e:
  206. app.state.ENABLED = False
  207. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  208. class UpdateModelForm(BaseModel):
  209. model: str
  210. def set_model_handler(model: str):
  211. if app.state.ENGINE == "openai":
  212. app.state.MODEL = model
  213. return app.state.MODEL
  214. if app.state.ENGINE == "comfyui":
  215. app.state.MODEL = model
  216. return app.state.MODEL
  217. else:
  218. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  219. options = r.json()
  220. if model != options["sd_model_checkpoint"]:
  221. options["sd_model_checkpoint"] = model
  222. r = requests.post(
  223. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
  224. )
  225. return options
  226. @app.post("/models/default/update")
  227. def update_default_model(
  228. form_data: UpdateModelForm,
  229. user=Depends(get_current_user),
  230. ):
  231. return set_model_handler(form_data.model)
  232. class GenerateImageForm(BaseModel):
  233. model: Optional[str] = None
  234. prompt: str
  235. n: int = 1
  236. size: Optional[str] = None
  237. negative_prompt: Optional[str] = None
  238. def save_b64_image(b64_str):
  239. image_id = str(uuid.uuid4())
  240. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
  241. try:
  242. # Split the base64 string to get the actual image data
  243. img_data = base64.b64decode(b64_str)
  244. # Write the image data to a file
  245. with open(file_path, "wb") as f:
  246. f.write(img_data)
  247. return image_id
  248. except Exception as e:
  249. log.error(f"Error saving image: {e}")
  250. return None
  251. def save_url_image(url):
  252. image_id = str(uuid.uuid4())
  253. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
  254. try:
  255. r = requests.get(url)
  256. r.raise_for_status()
  257. with open(file_path, "wb") as image_file:
  258. image_file.write(r.content)
  259. return image_id
  260. except Exception as e:
  261. log.exception(f"Error saving image: {e}")
  262. return None
  263. @app.post("/generations")
  264. def generate_image(
  265. form_data: GenerateImageForm,
  266. user=Depends(get_current_user),
  267. ):
  268. width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
  269. r = None
  270. try:
  271. if app.state.ENGINE == "openai":
  272. headers = {}
  273. headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
  274. headers["Content-Type"] = "application/json"
  275. data = {
  276. "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
  277. "prompt": form_data.prompt,
  278. "n": form_data.n,
  279. "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
  280. "response_format": "b64_json",
  281. }
  282. r = requests.post(
  283. url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
  284. json=data,
  285. headers=headers,
  286. )
  287. r.raise_for_status()
  288. res = r.json()
  289. images = []
  290. for image in res["data"]:
  291. image_id = save_b64_image(image["b64_json"])
  292. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  293. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  294. with open(file_body_path, "w") as f:
  295. json.dump(data, f)
  296. return images
  297. elif app.state.ENGINE == "comfyui":
  298. data = {
  299. "prompt": form_data.prompt,
  300. "width": width,
  301. "height": height,
  302. "n": form_data.n,
  303. }
  304. if app.state.IMAGE_STEPS != None:
  305. data["steps"] = app.state.IMAGE_STEPS
  306. if form_data.negative_prompt != None:
  307. data["negative_prompt"] = form_data.negative_prompt
  308. data = ImageGenerationPayload(**data)
  309. res = comfyui_generate_image(
  310. app.state.MODEL,
  311. data,
  312. user.id,
  313. app.state.COMFYUI_BASE_URL,
  314. )
  315. log.debug(f"res: {res}")
  316. images = []
  317. for image in res["data"]:
  318. image_id = save_url_image(image["url"])
  319. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  320. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  321. with open(file_body_path, "w") as f:
  322. json.dump(data.model_dump(exclude_none=True), f)
  323. log.debug(f"images: {images}")
  324. return images
  325. else:
  326. if form_data.model:
  327. set_model_handler(form_data.model)
  328. data = {
  329. "prompt": form_data.prompt,
  330. "batch_size": form_data.n,
  331. "width": width,
  332. "height": height,
  333. }
  334. if app.state.IMAGE_STEPS != None:
  335. data["steps"] = app.state.IMAGE_STEPS
  336. if form_data.negative_prompt != None:
  337. data["negative_prompt"] = form_data.negative_prompt
  338. r = requests.post(
  339. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  340. json=data,
  341. )
  342. res = r.json()
  343. log.debug(f"res: {res}")
  344. images = []
  345. for image in res["images"]:
  346. image_id = save_b64_image(image)
  347. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  348. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  349. with open(file_body_path, "w") as f:
  350. json.dump({**data, "info": res["info"]}, f)
  351. return images
  352. except Exception as e:
  353. error = e
  354. if r != None:
  355. data = r.json()
  356. if "error" in data:
  357. error = data["error"]["message"]
  358. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))