main.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
  21. from utils.misc import calculate_sha256
  22. from typing import Optional
  23. from pydantic import BaseModel
  24. from pathlib import Path
  25. import uuid
  26. import base64
  27. import json
  28. import logging
  29. from config import (
  30. SRC_LOG_LEVELS,
  31. CACHE_DIR,
  32. ENABLE_IMAGE_GENERATION,
  33. AUTOMATIC1111_BASE_URL,
  34. COMFYUI_BASE_URL,
  35. IMAGES_OPENAI_API_BASE_URL,
  36. IMAGES_OPENAI_API_KEY,
  37. )
  38. log = logging.getLogger(__name__)
  39. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  40. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  41. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  42. app = FastAPI()
  43. app.add_middleware(
  44. CORSMiddleware,
  45. allow_origins=["*"],
  46. allow_credentials=True,
  47. allow_methods=["*"],
  48. allow_headers=["*"],
  49. )
  50. app.state.ENGINE = ""
  51. app.state.ENABLED = ENABLE_IMAGE_GENERATION
  52. app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  53. app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
  54. app.state.MODEL = ""
  55. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  56. app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  57. app.state.IMAGE_SIZE = "512x512"
  58. app.state.IMAGE_STEPS = 50
  59. @app.get("/config")
  60. async def get_config(request: Request, user=Depends(get_admin_user)):
  61. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  62. class ConfigUpdateForm(BaseModel):
  63. engine: str
  64. enabled: bool
  65. @app.post("/config/update")
  66. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  67. app.state.ENGINE = form_data.engine
  68. app.state.ENABLED = form_data.enabled
  69. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  70. class EngineUrlUpdateForm(BaseModel):
  71. AUTOMATIC1111_BASE_URL: Optional[str] = None
  72. COMFYUI_BASE_URL: Optional[str] = None
  73. @app.get("/url")
  74. async def get_engine_url(user=Depends(get_admin_user)):
  75. return {
  76. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  77. "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
  78. }
  79. @app.post("/url/update")
  80. async def update_engine_url(
  81. form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  82. ):
  83. if form_data.AUTOMATIC1111_BASE_URL == None:
  84. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  85. else:
  86. url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
  87. try:
  88. r = requests.head(url)
  89. app.state.AUTOMATIC1111_BASE_URL = url
  90. except Exception as e:
  91. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  92. if form_data.COMFYUI_BASE_URL == None:
  93. app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  94. else:
  95. url = form_data.COMFYUI_BASE_URL.strip("/")
  96. try:
  97. r = requests.head(url)
  98. app.state.COMFYUI_BASE_URL = url
  99. except Exception as e:
  100. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  101. return {
  102. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  103. "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
  104. "status": True,
  105. }
  106. class OpenAIConfigUpdateForm(BaseModel):
  107. url: str
  108. key: str
  109. @app.get("/openai/config")
  110. async def get_openai_config(user=Depends(get_admin_user)):
  111. return {
  112. "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
  113. "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
  114. }
  115. @app.post("/openai/config/update")
  116. async def update_openai_config(
  117. form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
  118. ):
  119. if form_data.key == "":
  120. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  121. app.state.OPENAI_API_BASE_URL = form_data.url
  122. app.state.OPENAI_API_KEY = form_data.key
  123. return {
  124. "status": True,
  125. "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
  126. "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
  127. }
  128. class ImageSizeUpdateForm(BaseModel):
  129. size: str
  130. @app.get("/size")
  131. async def get_image_size(user=Depends(get_admin_user)):
  132. return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
  133. @app.post("/size/update")
  134. async def update_image_size(
  135. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  136. ):
  137. pattern = r"^\d+x\d+$" # Regular expression pattern
  138. if re.match(pattern, form_data.size):
  139. app.state.IMAGE_SIZE = form_data.size
  140. return {
  141. "IMAGE_SIZE": app.state.IMAGE_SIZE,
  142. "status": True,
  143. }
  144. else:
  145. raise HTTPException(
  146. status_code=400,
  147. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  148. )
  149. class ImageStepsUpdateForm(BaseModel):
  150. steps: int
  151. @app.get("/steps")
  152. async def get_image_size(user=Depends(get_admin_user)):
  153. return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
  154. @app.post("/steps/update")
  155. async def update_image_size(
  156. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  157. ):
  158. if form_data.steps >= 0:
  159. app.state.IMAGE_STEPS = form_data.steps
  160. return {
  161. "IMAGE_STEPS": app.state.IMAGE_STEPS,
  162. "status": True,
  163. }
  164. else:
  165. raise HTTPException(
  166. status_code=400,
  167. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  168. )
  169. @app.get("/models")
  170. def get_models(user=Depends(get_current_user)):
  171. try:
  172. if app.state.ENGINE == "openai":
  173. return [
  174. {"id": "dall-e-2", "name": "DALL·E 2"},
  175. {"id": "dall-e-3", "name": "DALL·E 3"},
  176. ]
  177. elif app.state.ENGINE == "comfyui":
  178. r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
  179. info = r.json()
  180. return list(
  181. map(
  182. lambda model: {"id": model, "name": model},
  183. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  184. )
  185. )
  186. else:
  187. r = requests.get(
  188. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
  189. )
  190. models = r.json()
  191. return list(
  192. map(
  193. lambda model: {"id": model["title"], "name": model["model_name"]},
  194. models,
  195. )
  196. )
  197. except Exception as e:
  198. app.state.ENABLED = False
  199. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  200. @app.get("/models/default")
  201. async def get_default_model(user=Depends(get_admin_user)):
  202. try:
  203. if app.state.ENGINE == "openai":
  204. return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
  205. elif app.state.ENGINE == "comfyui":
  206. return {"model": app.state.MODEL if app.state.MODEL else ""}
  207. else:
  208. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  209. options = r.json()
  210. return {"model": options["sd_model_checkpoint"]}
  211. except Exception as e:
  212. app.state.ENABLED = False
  213. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  214. class UpdateModelForm(BaseModel):
  215. model: str
  216. def set_model_handler(model: str):
  217. if app.state.ENGINE == "openai":
  218. app.state.MODEL = model
  219. return app.state.MODEL
  220. if app.state.ENGINE == "comfyui":
  221. app.state.MODEL = model
  222. return app.state.MODEL
  223. else:
  224. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  225. options = r.json()
  226. if model != options["sd_model_checkpoint"]:
  227. options["sd_model_checkpoint"] = model
  228. r = requests.post(
  229. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
  230. )
  231. return options
  232. @app.post("/models/default/update")
  233. def update_default_model(
  234. form_data: UpdateModelForm,
  235. user=Depends(get_current_user),
  236. ):
  237. return set_model_handler(form_data.model)
  238. class GenerateImageForm(BaseModel):
  239. model: Optional[str] = None
  240. prompt: str
  241. n: int = 1
  242. size: Optional[str] = None
  243. negative_prompt: Optional[str] = None
  244. def save_b64_image(b64_str):
  245. image_id = str(uuid.uuid4())
  246. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
  247. try:
  248. # Split the base64 string to get the actual image data
  249. img_data = base64.b64decode(b64_str)
  250. # Write the image data to a file
  251. with open(file_path, "wb") as f:
  252. f.write(img_data)
  253. return image_id
  254. except Exception as e:
  255. log.error(f"Error saving image: {e}")
  256. return None
  257. def save_url_image(url):
  258. image_id = str(uuid.uuid4())
  259. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
  260. try:
  261. r = requests.get(url)
  262. r.raise_for_status()
  263. with open(file_path, "wb") as image_file:
  264. image_file.write(r.content)
  265. return image_id
  266. except Exception as e:
  267. log.exception(f"Error saving image: {e}")
  268. return None
  269. @app.post("/generations")
  270. def generate_image(
  271. form_data: GenerateImageForm,
  272. user=Depends(get_current_user),
  273. ):
  274. width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
  275. r = None
  276. try:
  277. if app.state.ENGINE == "openai":
  278. headers = {}
  279. headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
  280. headers["Content-Type"] = "application/json"
  281. data = {
  282. "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
  283. "prompt": form_data.prompt,
  284. "n": form_data.n,
  285. "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
  286. "response_format": "b64_json",
  287. }
  288. r = requests.post(
  289. url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
  290. json=data,
  291. headers=headers,
  292. )
  293. r.raise_for_status()
  294. res = r.json()
  295. images = []
  296. for image in res["data"]:
  297. image_id = save_b64_image(image["b64_json"])
  298. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  299. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  300. with open(file_body_path, "w") as f:
  301. json.dump(data, f)
  302. return images
  303. elif app.state.ENGINE == "comfyui":
  304. data = {
  305. "prompt": form_data.prompt,
  306. "width": width,
  307. "height": height,
  308. "n": form_data.n,
  309. }
  310. if app.state.IMAGE_STEPS != None:
  311. data["steps"] = app.state.IMAGE_STEPS
  312. if form_data.negative_prompt != None:
  313. data["negative_prompt"] = form_data.negative_prompt
  314. data = ImageGenerationPayload(**data)
  315. res = comfyui_generate_image(
  316. app.state.MODEL,
  317. data,
  318. user.id,
  319. app.state.COMFYUI_BASE_URL,
  320. )
  321. log.debug(f"res: {res}")
  322. images = []
  323. for image in res["data"]:
  324. image_id = save_url_image(image["url"])
  325. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  326. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  327. with open(file_body_path, "w") as f:
  328. json.dump(data.model_dump(exclude_none=True), f)
  329. log.debug(f"images: {images}")
  330. return images
  331. else:
  332. if form_data.model:
  333. set_model_handler(form_data.model)
  334. data = {
  335. "prompt": form_data.prompt,
  336. "batch_size": form_data.n,
  337. "width": width,
  338. "height": height,
  339. }
  340. if app.state.IMAGE_STEPS != None:
  341. data["steps"] = app.state.IMAGE_STEPS
  342. if form_data.negative_prompt != None:
  343. data["negative_prompt"] = form_data.negative_prompt
  344. r = requests.post(
  345. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  346. json=data,
  347. )
  348. res = r.json()
  349. log.debug(f"res: {res}")
  350. images = []
  351. for image in res["images"]:
  352. image_id = save_b64_image(image)
  353. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  354. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  355. with open(file_body_path, "w") as f:
  356. json.dump({**data, "info": res["info"]}, f)
  357. return images
  358. except Exception as e:
  359. error = e
  360. if r != None:
  361. data = r.json()
  362. if "error" in data:
  363. error = data["error"]["message"]
  364. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))