main.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
  21. from utils.misc import calculate_sha256
  22. from typing import Optional
  23. from pydantic import BaseModel
  24. from pathlib import Path
  25. import uuid
  26. import base64
  27. import json
  28. import logging
  29. from config import (
  30. SRC_LOG_LEVELS,
  31. CACHE_DIR,
  32. IMAGES_GENERATION_ENGINE,
  33. ENABLE_IMAGE_GENERATION,
  34. AUTOMATIC1111_BASE_URL,
  35. COMFYUI_BASE_URL,
  36. IMAGES_OPENAI_API_BASE_URL,
  37. IMAGES_OPENAI_API_KEY,
  38. IMAGES_MODEL,
  39. IMAGE_SIZE,
  40. IMAGE_STEPS,
  41. )
  42. log = logging.getLogger(__name__)
  43. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  44. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  45. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  46. app = FastAPI()
  47. app.add_middleware(
  48. CORSMiddleware,
  49. allow_origins=["*"],
  50. allow_credentials=True,
  51. allow_methods=["*"],
  52. allow_headers=["*"],
  53. )
  54. app.state.ENGINE = IMAGES_GENERATION_ENGINE
  55. app.state.ENABLED = ENABLE_IMAGE_GENERATION
  56. app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  57. app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
  58. app.state.MODEL = IMAGES_MODEL
  59. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  60. app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  61. app.state.IMAGE_SIZE = IMAGE_SIZE
  62. app.state.IMAGE_STEPS = IMAGE_STEPS
  63. @app.get("/config")
  64. async def get_config(request: Request, user=Depends(get_admin_user)):
  65. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
class ConfigUpdateForm(BaseModel):
    # Request body for POST /config/update.
    # Stored verbatim in app.state.ENGINE; handlers in this module branch on
    # "openai" and "comfyui" and treat any other value as AUTOMATIC1111.
    engine: str
    # Stored in app.state.ENABLED.
    enabled: bool
  69. @app.post("/config/update")
  70. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  71. app.state.ENGINE = form_data.engine
  72. app.state.ENABLED = form_data.enabled
  73. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
class EngineUrlUpdateForm(BaseModel):
    # Request body for POST /url/update.
    # A field left as None resets that engine's URL to its configured default.
    AUTOMATIC1111_BASE_URL: Optional[str] = None
    COMFYUI_BASE_URL: Optional[str] = None
  77. @app.get("/url")
  78. async def get_engine_url(user=Depends(get_admin_user)):
  79. return {
  80. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  81. "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
  82. }
  83. @app.post("/url/update")
  84. async def update_engine_url(
  85. form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  86. ):
  87. if form_data.AUTOMATIC1111_BASE_URL == None:
  88. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  89. else:
  90. url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
  91. try:
  92. r = requests.head(url)
  93. app.state.AUTOMATIC1111_BASE_URL = url
  94. except Exception as e:
  95. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  96. if form_data.COMFYUI_BASE_URL == None:
  97. app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  98. else:
  99. url = form_data.COMFYUI_BASE_URL.strip("/")
  100. try:
  101. r = requests.head(url)
  102. app.state.COMFYUI_BASE_URL = url
  103. except Exception as e:
  104. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  105. return {
  106. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  107. "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
  108. "status": True,
  109. }
class OpenAIConfigUpdateForm(BaseModel):
    # Request body for POST /openai/config/update.
    # Base URL under which "/images/generations" is requested.
    url: str
    # Bearer token; an empty string is rejected with HTTP 400.
    key: str
  113. @app.get("/openai/config")
  114. async def get_openai_config(user=Depends(get_admin_user)):
  115. return {
  116. "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
  117. "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
  118. }
  119. @app.post("/openai/config/update")
  120. async def update_openai_config(
  121. form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
  122. ):
  123. if form_data.key == "":
  124. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  125. app.state.OPENAI_API_BASE_URL = form_data.url
  126. app.state.OPENAI_API_KEY = form_data.key
  127. return {
  128. "status": True,
  129. "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
  130. "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
  131. }
class ImageSizeUpdateForm(BaseModel):
    # Request body for POST /size/update; expected as "<width>x<height>",
    # e.g. "512x512".
    size: str
  134. @app.get("/size")
  135. async def get_image_size(user=Depends(get_admin_user)):
  136. return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
  137. @app.post("/size/update")
  138. async def update_image_size(
  139. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  140. ):
  141. pattern = r"^\d+x\d+$" # Regular expression pattern
  142. if re.match(pattern, form_data.size):
  143. app.state.IMAGE_SIZE = form_data.size
  144. return {
  145. "IMAGE_SIZE": app.state.IMAGE_SIZE,
  146. "status": True,
  147. }
  148. else:
  149. raise HTTPException(
  150. status_code=400,
  151. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  152. )
class ImageStepsUpdateForm(BaseModel):
    # Request body for POST /steps/update; negative values are rejected with
    # HTTP 400.
    steps: int
  155. @app.get("/steps")
  156. async def get_image_size(user=Depends(get_admin_user)):
  157. return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
  158. @app.post("/steps/update")
  159. async def update_image_size(
  160. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  161. ):
  162. if form_data.steps >= 0:
  163. app.state.IMAGE_STEPS = form_data.steps
  164. return {
  165. "IMAGE_STEPS": app.state.IMAGE_STEPS,
  166. "status": True,
  167. }
  168. else:
  169. raise HTTPException(
  170. status_code=400,
  171. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  172. )
  173. @app.get("/models")
  174. def get_models(user=Depends(get_current_user)):
  175. try:
  176. if app.state.ENGINE == "openai":
  177. return [
  178. {"id": "dall-e-2", "name": "DALL·E 2"},
  179. {"id": "dall-e-3", "name": "DALL·E 3"},
  180. ]
  181. elif app.state.ENGINE == "comfyui":
  182. r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
  183. info = r.json()
  184. return list(
  185. map(
  186. lambda model: {"id": model, "name": model},
  187. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  188. )
  189. )
  190. else:
  191. r = requests.get(
  192. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
  193. )
  194. models = r.json()
  195. return list(
  196. map(
  197. lambda model: {"id": model["title"], "name": model["model_name"]},
  198. models,
  199. )
  200. )
  201. except Exception as e:
  202. app.state.ENABLED = False
  203. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  204. @app.get("/models/default")
  205. async def get_default_model(user=Depends(get_admin_user)):
  206. try:
  207. if app.state.ENGINE == "openai":
  208. return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
  209. elif app.state.ENGINE == "comfyui":
  210. return {"model": app.state.MODEL if app.state.MODEL else ""}
  211. else:
  212. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  213. options = r.json()
  214. return {"model": options["sd_model_checkpoint"]}
  215. except Exception as e:
  216. app.state.ENABLED = False
  217. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class UpdateModelForm(BaseModel):
    # Request body for POST /models/default/update; passed to set_model_handler.
    model: str
  220. def set_model_handler(model: str):
  221. if app.state.ENGINE == "openai":
  222. app.state.MODEL = model
  223. return app.state.MODEL
  224. if app.state.ENGINE == "comfyui":
  225. app.state.MODEL = model
  226. return app.state.MODEL
  227. else:
  228. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  229. options = r.json()
  230. if model != options["sd_model_checkpoint"]:
  231. options["sd_model_checkpoint"] = model
  232. r = requests.post(
  233. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
  234. )
  235. return options
  236. @app.post("/models/default/update")
  237. def update_default_model(
  238. form_data: UpdateModelForm,
  239. user=Depends(get_current_user),
  240. ):
  241. return set_model_handler(form_data.model)
class GenerateImageForm(BaseModel):
    # Request body for POST /generations.
    # Only the AUTOMATIC1111 path reads this (it calls set_model_handler when
    # the value is truthy).
    model: Optional[str] = None
    prompt: str
    # Number of images: sent as "n" to OpenAI/ComfyUI and as "batch_size" to
    # AUTOMATIC1111.
    n: int = 1
    # "<width>x<height>"; only the OpenAI path reads it, falling back to
    # app.state.IMAGE_SIZE when unset.
    size: Optional[str] = None
    # Forwarded to ComfyUI and AUTOMATIC1111 when not None; not sent to OpenAI.
    negative_prompt: Optional[str] = None
  248. def save_b64_image(b64_str):
  249. image_id = str(uuid.uuid4())
  250. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
  251. try:
  252. # Split the base64 string to get the actual image data
  253. img_data = base64.b64decode(b64_str)
  254. # Write the image data to a file
  255. with open(file_path, "wb") as f:
  256. f.write(img_data)
  257. return image_id
  258. except Exception as e:
  259. log.error(f"Error saving image: {e}")
  260. return None
  261. def save_url_image(url):
  262. image_id = str(uuid.uuid4())
  263. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
  264. try:
  265. r = requests.get(url)
  266. r.raise_for_status()
  267. with open(file_path, "wb") as image_file:
  268. image_file.write(r.content)
  269. return image_id
  270. except Exception as e:
  271. log.exception(f"Error saving image: {e}")
  272. return None
  273. @app.post("/generations")
  274. def generate_image(
  275. form_data: GenerateImageForm,
  276. user=Depends(get_current_user),
  277. ):
  278. width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
  279. r = None
  280. try:
  281. if app.state.ENGINE == "openai":
  282. headers = {}
  283. headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
  284. headers["Content-Type"] = "application/json"
  285. data = {
  286. "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
  287. "prompt": form_data.prompt,
  288. "n": form_data.n,
  289. "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
  290. "response_format": "b64_json",
  291. }
  292. r = requests.post(
  293. url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
  294. json=data,
  295. headers=headers,
  296. )
  297. r.raise_for_status()
  298. res = r.json()
  299. images = []
  300. for image in res["data"]:
  301. image_id = save_b64_image(image["b64_json"])
  302. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  303. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  304. with open(file_body_path, "w") as f:
  305. json.dump(data, f)
  306. return images
  307. elif app.state.ENGINE == "comfyui":
  308. data = {
  309. "prompt": form_data.prompt,
  310. "width": width,
  311. "height": height,
  312. "n": form_data.n,
  313. }
  314. if app.state.IMAGE_STEPS != None:
  315. data["steps"] = app.state.IMAGE_STEPS
  316. if form_data.negative_prompt != None:
  317. data["negative_prompt"] = form_data.negative_prompt
  318. data = ImageGenerationPayload(**data)
  319. res = comfyui_generate_image(
  320. app.state.MODEL,
  321. data,
  322. user.id,
  323. app.state.COMFYUI_BASE_URL,
  324. )
  325. log.debug(f"res: {res}")
  326. images = []
  327. for image in res["data"]:
  328. image_id = save_url_image(image["url"])
  329. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  330. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  331. with open(file_body_path, "w") as f:
  332. json.dump(data.model_dump(exclude_none=True), f)
  333. log.debug(f"images: {images}")
  334. return images
  335. else:
  336. if form_data.model:
  337. set_model_handler(form_data.model)
  338. data = {
  339. "prompt": form_data.prompt,
  340. "batch_size": form_data.n,
  341. "width": width,
  342. "height": height,
  343. }
  344. if app.state.IMAGE_STEPS != None:
  345. data["steps"] = app.state.IMAGE_STEPS
  346. if form_data.negative_prompt != None:
  347. data["negative_prompt"] = form_data.negative_prompt
  348. r = requests.post(
  349. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  350. json=data,
  351. )
  352. res = r.json()
  353. log.debug(f"res: {res}")
  354. images = []
  355. for image in res["images"]:
  356. image_id = save_b64_image(image)
  357. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  358. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  359. with open(file_body_path, "w") as f:
  360. json.dump({**data, "info": res["info"]}, f)
  361. return images
  362. except Exception as e:
  363. error = e
  364. if r != None:
  365. data = r.json()
  366. if "error" in data:
  367. error = data["error"]["message"]
  368. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))