main.py 17 KB


  1. import re
  2. import requests
  3. import base64
  4. from fastapi import (
  5. FastAPI,
  6. Request,
  7. Depends,
  8. HTTPException,
  9. status,
  10. UploadFile,
  11. File,
  12. Form,
  13. )
  14. from fastapi.middleware.cors import CORSMiddleware
  15. from faster_whisper import WhisperModel
  16. from constants import ERROR_MESSAGES
  17. from utils.utils import (
  18. get_current_user,
  19. get_admin_user,
  20. )
  21. from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
  22. from utils.misc import calculate_sha256
  23. from typing import Optional
  24. from pydantic import BaseModel
  25. from pathlib import Path
  26. import mimetypes
  27. import uuid
  28. import base64
  29. import json
  30. import logging
  31. from config import (
  32. SRC_LOG_LEVELS,
  33. CACHE_DIR,
  34. IMAGE_GENERATION_ENGINE,
  35. ENABLE_IMAGE_GENERATION,
  36. AUTOMATIC1111_BASE_URL,
  37. AUTOMATIC1111_API_AUTH,
  38. COMFYUI_BASE_URL,
  39. COMFYUI_CFG_SCALE,
  40. COMFYUI_SAMPLER,
  41. COMFYUI_SCHEDULER,
  42. COMFYUI_SD3,
  43. IMAGES_OPENAI_API_BASE_URL,
  44. IMAGES_OPENAI_API_KEY,
  45. IMAGE_GENERATION_MODEL,
  46. IMAGE_SIZE,
  47. IMAGE_STEPS,
  48. AppConfig,
  49. )
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
# Generated images are written to this directory together with a
# "<file>.json" sidecar holding the request metadata; generate_image hands
# clients "/cache/image/generations/<file>" URLs pointing at these files.
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this sub-app is only reachable behind the parent app's
# auth/CORS policy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Runtime-mutable settings seeded from the config module. The admin endpoints
# in this file overwrite these attributes in place, and every handler reads
# them on each request.
app.state.config = AppConfig()
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
  77. def get_automatic1111_api_auth():
  78. if app.state.config.AUTOMATIC1111_API_AUTH == None:
  79. return ""
  80. else:
  81. auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode('utf-8')
  82. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  83. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode('utf-8')
  84. return f"Basic {auth1111_base64_encoded_string}"
  85. @app.get("/config")
  86. async def get_config(request: Request, user=Depends(get_admin_user)):
  87. return {
  88. "engine": app.state.config.ENGINE,
  89. "enabled": app.state.config.ENABLED,
  90. }
class ConfigUpdateForm(BaseModel):
    """Request body for ``POST /config/update``."""

    # Engine id. Handlers here special-case "openai" and "comfyui"; any other
    # value is treated as AUTOMATIC1111.
    engine: str
    # Master on/off switch for image generation.
    enabled: bool
  94. @app.post("/config/update")
  95. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  96. app.state.config.ENGINE = form_data.engine
  97. app.state.config.ENABLED = form_data.enabled
  98. return {
  99. "engine": app.state.config.ENGINE,
  100. "enabled": app.state.config.ENABLED,
  101. }
class EngineUrlUpdateForm(BaseModel):
    """Request body for ``POST /url/update``.

    Any field left as None makes the endpoint restore the value from the
    config module instead of keeping the current runtime value.
    """

    AUTOMATIC1111_BASE_URL: Optional[str] = None
    # Credentials turned into an HTTP Basic header by get_automatic1111_api_auth.
    AUTOMATIC1111_API_AUTH: Optional[str] = None
    COMFYUI_BASE_URL: Optional[str] = None
  106. @app.get("/url")
  107. async def get_engine_url(user=Depends(get_admin_user)):
  108. return {
  109. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  110. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  111. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  112. }
  113. @app.post("/url/update")
  114. async def update_engine_url(
  115. form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  116. ):
  117. if form_data.AUTOMATIC1111_BASE_URL == None:
  118. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  119. else:
  120. url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
  121. try:
  122. r = requests.head(url)
  123. app.state.config.AUTOMATIC1111_BASE_URL = url
  124. except Exception as e:
  125. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  126. if form_data.COMFYUI_BASE_URL == None:
  127. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  128. else:
  129. url = form_data.COMFYUI_BASE_URL.strip("/")
  130. try:
  131. r = requests.head(url)
  132. app.state.config.COMFYUI_BASE_URL = url
  133. except Exception as e:
  134. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  135. if form_data.AUTOMATIC1111_API_AUTH == None:
  136. app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
  137. else:
  138. app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
  139. return {
  140. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  141. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  142. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  143. "status": True,
  144. }
class OpenAIConfigUpdateForm(BaseModel):
    """Request body for ``POST /openai/config/update``."""

    # Base URL; generate_image appends "/images/generations" to it.
    url: str
    # API key sent as a Bearer token; the endpoint rejects an empty string.
    key: str
  148. @app.get("/openai/config")
  149. async def get_openai_config(user=Depends(get_admin_user)):
  150. return {
  151. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  152. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  153. }
  154. @app.post("/openai/config/update")
  155. async def update_openai_config(
  156. form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
  157. ):
  158. if form_data.key == "":
  159. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  160. app.state.config.OPENAI_API_BASE_URL = form_data.url
  161. app.state.config.OPENAI_API_KEY = form_data.key
  162. return {
  163. "status": True,
  164. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  165. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  166. }
class ImageSizeUpdateForm(BaseModel):
    """Request body for ``POST /size/update``."""

    # Expected as "<width>x<height>", e.g. "512x512".
    size: str
  169. @app.get("/size")
  170. async def get_image_size(user=Depends(get_admin_user)):
  171. return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
  172. @app.post("/size/update")
  173. async def update_image_size(
  174. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  175. ):
  176. pattern = r"^\d+x\d+$" # Regular expression pattern
  177. if re.match(pattern, form_data.size):
  178. app.state.config.IMAGE_SIZE = form_data.size
  179. return {
  180. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  181. "status": True,
  182. }
  183. else:
  184. raise HTTPException(
  185. status_code=400,
  186. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  187. )
class ImageStepsUpdateForm(BaseModel):
    """Request body for ``POST /steps/update``."""

    # Default step count sent to ComfyUI/AUTOMATIC1111; the endpoint rejects
    # negative values.
    steps: int
  190. @app.get("/steps")
  191. async def get_image_size(user=Depends(get_admin_user)):
  192. return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
  193. @app.post("/steps/update")
  194. async def update_image_size(
  195. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  196. ):
  197. if form_data.steps >= 0:
  198. app.state.config.IMAGE_STEPS = form_data.steps
  199. return {
  200. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  201. "status": True,
  202. }
  203. else:
  204. raise HTTPException(
  205. status_code=400,
  206. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  207. )
  208. @app.get("/models")
  209. def get_models(user=Depends(get_current_user)):
  210. try:
  211. if app.state.config.ENGINE == "openai":
  212. return [
  213. {"id": "dall-e-2", "name": "DALL·E 2"},
  214. {"id": "dall-e-3", "name": "DALL·E 3"},
  215. ]
  216. elif app.state.config.ENGINE == "comfyui":
  217. r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
  218. info = r.json()
  219. return list(
  220. map(
  221. lambda model: {"id": model, "name": model},
  222. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  223. )
  224. )
  225. else:
  226. r = requests.get(
  227. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  228. headers={"authorization": get_automatic1111_api_auth()}
  229. )
  230. models = r.json()
  231. return list(
  232. map(
  233. lambda model: {"id": model["title"], "name": model["model_name"]},
  234. models,
  235. )
  236. )
  237. except Exception as e:
  238. app.state.config.ENABLED = False
  239. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  240. @app.get("/models/default")
  241. async def get_default_model(user=Depends(get_admin_user)):
  242. try:
  243. if app.state.config.ENGINE == "openai":
  244. return {
  245. "model": (
  246. app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
  247. )
  248. }
  249. elif app.state.config.ENGINE == "comfyui":
  250. return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
  251. else:
  252. r = requests.get(
  253. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  254. headers={"authorization": get_automatic1111_api_auth()}
  255. )
  256. options = r.json()
  257. return {"model": options["sd_model_checkpoint"]}
  258. except Exception as e:
  259. app.state.config.ENABLED = False
  260. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class UpdateModelForm(BaseModel):
    """Request body for ``POST /models/default/update``."""

    # Model id: stored as MODEL for openai/comfyui, or set as AUTOMATIC1111's
    # sd_model_checkpoint otherwise.
    model: str
  263. def set_model_handler(model: str):
  264. if app.state.config.ENGINE in ["openai", "comfyui"]:
  265. app.state.config.MODEL = model
  266. return app.state.config.MODEL
  267. else:
  268. api_auth = get_automatic1111_api_auth()
  269. r = requests.get(
  270. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  271. headers={"authorization": api_auth}
  272. )
  273. options = r.json()
  274. if model != options["sd_model_checkpoint"]:
  275. options["sd_model_checkpoint"] = model
  276. r = requests.post(
  277. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  278. json=options,
  279. headers={"authorization": api_auth},
  280. )
  281. return options
  282. @app.post("/models/default/update")
  283. def update_default_model(
  284. form_data: UpdateModelForm,
  285. user=Depends(get_current_user),
  286. ):
  287. return set_model_handler(form_data.model)
class GenerateImageForm(BaseModel):
    """Request body for ``POST /generations``."""

    # Only consulted on the AUTOMATIC1111 path, where it switches the server
    # checkpoint; openai/comfyui use the configured MODEL.
    model: Optional[str] = None
    prompt: str
    # Number of images (sent as "n" to openai/comfyui, "batch_size" to A1111).
    n: int = 1
    # "<width>x<height>"; only honoured by the openai path. Other engines use
    # the configured IMAGE_SIZE.
    size: Optional[str] = None
    # Forwarded to comfyui/AUTOMATIC1111 when set; ignored by openai.
    negative_prompt: Optional[str] = None
  294. def save_b64_image(b64_str):
  295. try:
  296. image_id = str(uuid.uuid4())
  297. if "," in b64_str:
  298. header, encoded = b64_str.split(",", 1)
  299. mime_type = header.split(";")[0]
  300. img_data = base64.b64decode(encoded)
  301. image_format = mimetypes.guess_extension(mime_type)
  302. image_filename = f"{image_id}{image_format}"
  303. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  304. with open(file_path, "wb") as f:
  305. f.write(img_data)
  306. return image_filename
  307. else:
  308. image_filename = f"{image_id}.png"
  309. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  310. img_data = base64.b64decode(b64_str)
  311. # Write the image data to a file
  312. with open(file_path, "wb") as f:
  313. f.write(img_data)
  314. return image_filename
  315. except Exception as e:
  316. log.exception(f"Error saving image: {e}")
  317. return None
  318. def save_url_image(url):
  319. image_id = str(uuid.uuid4())
  320. try:
  321. r = requests.get(url)
  322. r.raise_for_status()
  323. if r.headers["content-type"].split("/")[0] == "image":
  324. mime_type = r.headers["content-type"]
  325. image_format = mimetypes.guess_extension(mime_type)
  326. if not image_format:
  327. raise ValueError("Could not determine image type from MIME type")
  328. image_filename = f"{image_id}{image_format}"
  329. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  330. with open(file_path, "wb") as image_file:
  331. for chunk in r.iter_content(chunk_size=8192):
  332. image_file.write(chunk)
  333. return image_filename
  334. else:
  335. log.error(f"Url does not point to an image.")
  336. return None
  337. except Exception as e:
  338. log.exception(f"Error saving image: {e}")
  339. return None
  340. @app.post("/generations")
  341. def generate_image(
  342. form_data: GenerateImageForm,
  343. user=Depends(get_current_user),
  344. ):
  345. width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
  346. r = None
  347. try:
  348. if app.state.config.ENGINE == "openai":
  349. headers = {}
  350. headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
  351. headers["Content-Type"] = "application/json"
  352. data = {
  353. "model": (
  354. app.state.config.MODEL
  355. if app.state.config.MODEL != ""
  356. else "dall-e-2"
  357. ),
  358. "prompt": form_data.prompt,
  359. "n": form_data.n,
  360. "size": (
  361. form_data.size if form_data.size else app.state.config.IMAGE_SIZE
  362. ),
  363. "response_format": "b64_json",
  364. }
  365. r = requests.post(
  366. url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
  367. json=data,
  368. headers=headers,
  369. )
  370. r.raise_for_status()
  371. res = r.json()
  372. images = []
  373. for image in res["data"]:
  374. image_filename = save_b64_image(image["b64_json"])
  375. images.append({"url": f"/cache/image/generations/{image_filename}"})
  376. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  377. with open(file_body_path, "w") as f:
  378. json.dump(data, f)
  379. return images
  380. elif app.state.config.ENGINE == "comfyui":
  381. data = {
  382. "prompt": form_data.prompt,
  383. "width": width,
  384. "height": height,
  385. "n": form_data.n,
  386. }
  387. if app.state.config.IMAGE_STEPS is not None:
  388. data["steps"] = app.state.config.IMAGE_STEPS
  389. if form_data.negative_prompt is not None:
  390. data["negative_prompt"] = form_data.negative_prompt
  391. if app.state.config.COMFYUI_CFG_SCALE:
  392. data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
  393. if app.state.config.COMFYUI_SAMPLER is not None:
  394. data["sampler"] = app.state.config.COMFYUI_SAMPLER
  395. if app.state.config.COMFYUI_SCHEDULER is not None:
  396. data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
  397. if app.state.config.COMFYUI_SD3 is not None:
  398. data["sd3"] = app.state.config.COMFYUI_SD3
  399. data = ImageGenerationPayload(**data)
  400. res = comfyui_generate_image(
  401. app.state.config.MODEL,
  402. data,
  403. user.id,
  404. app.state.config.COMFYUI_BASE_URL,
  405. )
  406. log.debug(f"res: {res}")
  407. images = []
  408. for image in res["data"]:
  409. image_filename = save_url_image(image["url"])
  410. images.append({"url": f"/cache/image/generations/{image_filename}"})
  411. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  412. with open(file_body_path, "w") as f:
  413. json.dump(data.model_dump(exclude_none=True), f)
  414. log.debug(f"images: {images}")
  415. return images
  416. else:
  417. if form_data.model:
  418. set_model_handler(form_data.model)
  419. data = {
  420. "prompt": form_data.prompt,
  421. "batch_size": form_data.n,
  422. "width": width,
  423. "height": height,
  424. }
  425. if app.state.config.IMAGE_STEPS is not None:
  426. data["steps"] = app.state.config.IMAGE_STEPS
  427. if form_data.negative_prompt is not None:
  428. data["negative_prompt"] = form_data.negative_prompt
  429. r = requests.post(
  430. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  431. json=data,
  432. headers={"authorization": get_automatic1111_api_auth()},
  433. )
  434. res = r.json()
  435. log.debug(f"res: {res}")
  436. images = []
  437. for image in res["images"]:
  438. image_filename = save_b64_image(image)
  439. images.append({"url": f"/cache/image/generations/{image_filename}"})
  440. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  441. with open(file_body_path, "w") as f:
  442. json.dump({**data, "info": res["info"]}, f)
  443. return images
  444. except Exception as e:
  445. error = e
  446. if r != None:
  447. data = r.json()
  448. if "error" in data:
  449. error = data["error"]["message"]
  450. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))