main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
  21. from utils.misc import calculate_sha256
  22. from typing import Optional
  23. from pydantic import BaseModel
  24. from pathlib import Path
  25. import uuid
  26. import base64
  27. import json
  28. from config import CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL
  29. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  30. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  31. app = FastAPI()
  32. app.add_middleware(
  33. CORSMiddleware,
  34. allow_origins=["*"],
  35. allow_credentials=True,
  36. allow_methods=["*"],
  37. allow_headers=["*"],
  38. )
  39. app.state.ENGINE = ""
  40. app.state.ENABLED = False
  41. app.state.OPENAI_API_KEY = ""
  42. app.state.MODEL = ""
  43. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  44. app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  45. app.state.IMAGE_SIZE = "512x512"
  46. app.state.IMAGE_STEPS = 50
  47. @app.get("/config")
  48. async def get_config(request: Request, user=Depends(get_admin_user)):
  49. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  50. class ConfigUpdateForm(BaseModel):
  51. engine: str
  52. enabled: bool
  53. @app.post("/config/update")
  54. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  55. app.state.ENGINE = form_data.engine
  56. app.state.ENABLED = form_data.enabled
  57. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  58. class EngineUrlUpdateForm(BaseModel):
  59. AUTOMATIC1111_BASE_URL: Optional[str] = None
  60. COMFYUI_BASE_URL: Optional[str] = None
  61. @app.get("/url")
  62. async def get_engine_url(user=Depends(get_admin_user)):
  63. return {
  64. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  65. "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
  66. }
  67. @app.post("/url/update")
  68. async def update_engine_url(
  69. form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  70. ):
  71. if form_data.AUTOMATIC1111_BASE_URL == None:
  72. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  73. else:
  74. url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
  75. try:
  76. r = requests.head(url)
  77. app.state.AUTOMATIC1111_BASE_URL = url
  78. except Exception as e:
  79. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  80. if form_data.COMFYUI_BASE_URL == None:
  81. app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  82. else:
  83. url = form_data.COMFYUI_BASE_URL.strip("/")
  84. try:
  85. r = requests.head(url)
  86. app.state.COMFYUI_BASE_URL = url
  87. except Exception as e:
  88. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  89. return {
  90. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  91. "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
  92. "status": True,
  93. }
  94. class OpenAIKeyUpdateForm(BaseModel):
  95. key: str
  96. @app.get("/key")
  97. async def get_openai_key(user=Depends(get_admin_user)):
  98. return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
  99. @app.post("/key/update")
  100. async def update_openai_key(
  101. form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
  102. ):
  103. if form_data.key == "":
  104. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  105. app.state.OPENAI_API_KEY = form_data.key
  106. return {
  107. "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
  108. "status": True,
  109. }
  110. class ImageSizeUpdateForm(BaseModel):
  111. size: str
  112. @app.get("/size")
  113. async def get_image_size(user=Depends(get_admin_user)):
  114. return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
  115. @app.post("/size/update")
  116. async def update_image_size(
  117. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  118. ):
  119. pattern = r"^\d+x\d+$" # Regular expression pattern
  120. if re.match(pattern, form_data.size):
  121. app.state.IMAGE_SIZE = form_data.size
  122. return {
  123. "IMAGE_SIZE": app.state.IMAGE_SIZE,
  124. "status": True,
  125. }
  126. else:
  127. raise HTTPException(
  128. status_code=400,
  129. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  130. )
  131. class ImageStepsUpdateForm(BaseModel):
  132. steps: int
  133. @app.get("/steps")
  134. async def get_image_size(user=Depends(get_admin_user)):
  135. return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
  136. @app.post("/steps/update")
  137. async def update_image_size(
  138. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  139. ):
  140. if form_data.steps >= 0:
  141. app.state.IMAGE_STEPS = form_data.steps
  142. return {
  143. "IMAGE_STEPS": app.state.IMAGE_STEPS,
  144. "status": True,
  145. }
  146. else:
  147. raise HTTPException(
  148. status_code=400,
  149. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  150. )
  151. @app.get("/models")
  152. def get_models(user=Depends(get_current_user)):
  153. try:
  154. if app.state.ENGINE == "openai":
  155. return [
  156. {"id": "dall-e-2", "name": "DALL·E 2"},
  157. {"id": "dall-e-3", "name": "DALL·E 3"},
  158. ]
  159. elif app.state.ENGINE == "comfyui":
  160. r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
  161. info = r.json()
  162. return list(
  163. map(
  164. lambda model: {"id": model, "name": model},
  165. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  166. )
  167. )
  168. else:
  169. r = requests.get(
  170. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
  171. )
  172. models = r.json()
  173. return list(
  174. map(
  175. lambda model: {"id": model["title"], "name": model["model_name"]},
  176. models,
  177. )
  178. )
  179. except Exception as e:
  180. app.state.ENABLED = False
  181. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  182. @app.get("/models/default")
  183. async def get_default_model(user=Depends(get_admin_user)):
  184. try:
  185. if app.state.ENGINE == "openai":
  186. return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
  187. elif app.state.ENGINE == "comfyui":
  188. return {"model": app.state.MODEL if app.state.MODEL else ""}
  189. else:
  190. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  191. options = r.json()
  192. return {"model": options["sd_model_checkpoint"]}
  193. except Exception as e:
  194. app.state.ENABLED = False
  195. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  196. class UpdateModelForm(BaseModel):
  197. model: str
  198. def set_model_handler(model: str):
  199. if app.state.ENGINE == "openai":
  200. app.state.MODEL = model
  201. return app.state.MODEL
  202. if app.state.ENGINE == "comfyui":
  203. app.state.MODEL = model
  204. return app.state.MODEL
  205. else:
  206. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  207. options = r.json()
  208. if model != options["sd_model_checkpoint"]:
  209. options["sd_model_checkpoint"] = model
  210. r = requests.post(
  211. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
  212. )
  213. return options
  214. @app.post("/models/default/update")
  215. def update_default_model(
  216. form_data: UpdateModelForm,
  217. user=Depends(get_current_user),
  218. ):
  219. return set_model_handler(form_data.model)
  220. class GenerateImageForm(BaseModel):
  221. model: Optional[str] = None
  222. prompt: str
  223. n: int = 1
  224. size: Optional[str] = None
  225. negative_prompt: Optional[str] = None
  226. def save_b64_image(b64_str):
  227. image_id = str(uuid.uuid4())
  228. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
  229. try:
  230. # Split the base64 string to get the actual image data
  231. img_data = base64.b64decode(b64_str)
  232. # Write the image data to a file
  233. with open(file_path, "wb") as f:
  234. f.write(img_data)
  235. return image_id
  236. except Exception as e:
  237. print(f"Error saving image: {e}")
  238. return None
  239. def save_url_image(url):
  240. image_id = str(uuid.uuid4())
  241. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
  242. try:
  243. r = requests.get(url)
  244. r.raise_for_status()
  245. with open(file_path, "wb") as image_file:
  246. image_file.write(r.content)
  247. return image_id
  248. except Exception as e:
  249. print(f"Error saving image: {e}")
  250. return None
  251. @app.post("/generations")
  252. def generate_image(
  253. form_data: GenerateImageForm,
  254. user=Depends(get_current_user),
  255. ):
  256. width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
  257. r = None
  258. try:
  259. if app.state.ENGINE == "openai":
  260. headers = {}
  261. headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
  262. headers["Content-Type"] = "application/json"
  263. data = {
  264. "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
  265. "prompt": form_data.prompt,
  266. "n": form_data.n,
  267. "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
  268. "response_format": "b64_json",
  269. }
  270. r = requests.post(
  271. url=f"https://api.openai.com/v1/images/generations",
  272. json=data,
  273. headers=headers,
  274. )
  275. r.raise_for_status()
  276. res = r.json()
  277. images = []
  278. for image in res["data"]:
  279. image_id = save_b64_image(image["b64_json"])
  280. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  281. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  282. with open(file_body_path, "w") as f:
  283. json.dump(data, f)
  284. return images
  285. elif app.state.ENGINE == "comfyui":
  286. data = {
  287. "prompt": form_data.prompt,
  288. "width": width,
  289. "height": height,
  290. "n": form_data.n,
  291. }
  292. if app.state.IMAGE_STEPS != None:
  293. data["steps"] = app.state.IMAGE_STEPS
  294. if form_data.negative_prompt != None:
  295. data["negative_prompt"] = form_data.negative_prompt
  296. data = ImageGenerationPayload(**data)
  297. res = comfyui_generate_image(
  298. app.state.MODEL,
  299. data,
  300. user.id,
  301. app.state.COMFYUI_BASE_URL,
  302. )
  303. print(res)
  304. images = []
  305. for image in res["data"]:
  306. image_id = save_url_image(image["url"])
  307. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  308. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  309. with open(file_body_path, "w") as f:
  310. json.dump(data.model_dump(exclude_none=True), f)
  311. print(images)
  312. return images
  313. else:
  314. if form_data.model:
  315. set_model_handler(form_data.model)
  316. data = {
  317. "prompt": form_data.prompt,
  318. "batch_size": form_data.n,
  319. "width": width,
  320. "height": height,
  321. }
  322. if app.state.IMAGE_STEPS != None:
  323. data["steps"] = app.state.IMAGE_STEPS
  324. if form_data.negative_prompt != None:
  325. data["negative_prompt"] = form_data.negative_prompt
  326. r = requests.post(
  327. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  328. json=data,
  329. )
  330. res = r.json()
  331. print(res)
  332. images = []
  333. for image in res["images"]:
  334. image_id = save_b64_image(image)
  335. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  336. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  337. with open(file_body_path, "w") as f:
  338. json.dump({**data, "info": res["info"]}, f)
  339. return images
  340. except Exception as e:
  341. error = e
  342. if r != None:
  343. data = r.json()
  344. if "error" in data:
  345. error = data["error"]["message"]
  346. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))