main.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from utils.misc import calculate_sha256
  21. from typing import Optional
  22. from pydantic import BaseModel
  23. from pathlib import Path
  24. import uuid
  25. import base64
  26. import json
  27. import logging
  28. from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["IMAGES"])

  # On-disk cache for generated images and their JSON request metadata.
  IMAGE_CACHE_DIR = Path(CACHE_DIR) / "image" / "generations"
  IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

  app = FastAPI()
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

  # Runtime configuration, mutated in place by the admin endpoints below.
  # ENGINE == "openai" selects the OpenAI images API; any other value routes
  # requests to the AUTOMATIC1111 server at AUTOMATIC1111_BASE_URL.
  app.state.ENGINE = ""
  app.state.ENABLED = False
  app.state.OPENAI_API_KEY = ""
  app.state.MODEL = ""  # empty means "dall-e-2" for the OpenAI engine
  app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  app.state.IMAGE_SIZE = "512x512"  # "<width>x<height>"
  app.state.IMAGE_STEPS = 50
  @app.get("/config")
  async def get_config(request: Request, user=Depends(get_admin_user)):
      """Report the selected image engine and whether generation is enabled.

      Guarded by the ``get_admin_user`` dependency.
      """
      state = app.state
      return {"engine": state.ENGINE, "enabled": state.ENABLED}
  class ConfigUpdateForm(BaseModel):
      """Request body for ``POST /config/update``."""

      engine: str  # copied verbatim into app.state.ENGINE
      enabled: bool  # copied into app.state.ENABLED
  @app.post("/config/update")
  async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
      """Apply a new engine/enabled setting and echo the resulting configuration."""
      app.state.ENGINE, app.state.ENABLED = form_data.engine, form_data.enabled
      return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  class UrlUpdateForm(BaseModel):
      """Request body for ``POST /url/update``."""

      url: str  # new AUTOMATIC1111 base URL; "" restores the configured default
  @app.get("/url")
  async def get_automatic1111_url(user=Depends(get_admin_user)):
      """Return the AUTOMATIC1111 base URL currently in use."""
      current_url = app.state.AUTOMATIC1111_BASE_URL
      return {"AUTOMATIC1111_BASE_URL": current_url}
  @app.post("/url/update")
  def update_automatic1111_url(
      form_data: UrlUpdateForm, user=Depends(get_admin_user)
  ):
      """Set the AUTOMATIC1111 base URL after checking that it is reachable.

      An empty ``url`` restores ``AUTOMATIC1111_BASE_URL`` from ``config``.
      Otherwise trailing slashes are removed, so that paths such as
      ``/sdapi/v1/...`` can be appended directly, and a HEAD request is sent.
      The URL is stored only if that request completes, whatever its HTTP
      status code.

      This is a plain ``def``, like the other endpoints here that call external
      servers. FastAPI therefore runs the blocking ``requests`` call in its
      threadpool instead of stalling the event loop.

      Returns:
          ``{"AUTOMATIC1111_BASE_URL": <stored url>, "status": True}``.

      Raises:
          HTTPException: 400 if the HEAD request fails (connection error,
              invalid URL, or timeout).
      """
      # Seconds. Without an explicit timeout, requests may wait indefinitely.
      reachability_timeout = 10

      if form_data.url == "":
          app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
      else:
          url = form_data.url.rstrip("/")
          try:
              requests.head(url, timeout=reachability_timeout)
          except Exception as e:
              raise HTTPException(
                  status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)
              ) from e
          app.state.AUTOMATIC1111_BASE_URL = url

      return {
          "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
          "status": True,
      }
  class OpenAIKeyUpdateForm(BaseModel):
      """Request body for ``POST /key/update``."""

      key: str  # the endpoint rejects an empty string
  @app.get("/key")
  async def get_openai_key(user=Depends(get_admin_user)):
      """Return the stored OpenAI API key."""
      stored_key = app.state.OPENAI_API_KEY
      return {"OPENAI_API_KEY": stored_key}
  @app.post("/key/update")
  async def update_openai_key(
      form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
  ):
      """Replace the stored OpenAI API key and echo it back.

      Raises:
          HTTPException: 400 with ``ERROR_MESSAGES.API_KEY_NOT_FOUND`` if the
              submitted key is an empty string.
      """
      submitted = form_data.key
      if submitted == "":
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

      app.state.OPENAI_API_KEY = submitted
      return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "status": True}
  class ImageSizeUpdateForm(BaseModel):
      """Request body for ``POST /size/update``."""

      size: str  # expected as "<width>x<height>", e.g. "512x512"
  @app.get("/size")
  async def get_image_size(user=Depends(get_admin_user)):
      """Return the default image size string, e.g. "512x512"."""
      return dict(IMAGE_SIZE=app.state.IMAGE_SIZE)
  @app.post("/size/update")
  async def update_image_size(
      form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  ):
      """Set the default image size, given as ``"<width>x<height>"``.

      Both dimensions must be positive integers written with ASCII digits.
      ``re.fullmatch`` rejects a trailing newline, which ``re.match`` with
      ``$`` would accept. ``re.ASCII`` restricts the digit class to 0-9.
      This matters because the stored value is sent verbatim to OpenAI and
      is split and ``int``-parsed for AUTOMATIC1111.

      Returns:
          ``{"IMAGE_SIZE": <stored size>, "status": True}``.

      Raises:
          HTTPException: 400 with an ``INCORRECT_FORMAT`` message for any
              other input.
      """
      parsed = re.fullmatch(r"(\d+)x(\d+)", form_data.size, flags=re.ASCII)
      if parsed is None or 0 in (int(parsed.group(1)), int(parsed.group(2))):
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
          )

      app.state.IMAGE_SIZE = form_data.size
      return {
          "IMAGE_SIZE": app.state.IMAGE_SIZE,
          "status": True,
      }
  class ImageStepsUpdateForm(BaseModel):
      """Request body for ``POST /steps/update``."""

      steps: int  # the endpoint rejects negative values
  @app.get("/steps")
  async def get_image_steps(user=Depends(get_admin_user)):
      """Return the default number of sampling steps (``app.state.IMAGE_STEPS``).

      This is named distinctly so that it does not rebind the module-level
      ``get_image_size`` that belongs to the ``/size`` handler. The ``/steps``
      path is unaffected by the function name.
      """
      return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
  @app.post("/steps/update")
  async def update_image_steps(
      form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  ):
      """Set the default number of sampling steps.

      This is named distinctly so that it does not rebind the module-level
      ``update_image_size`` that belongs to the ``/size/update`` handler. The
      ``/steps/update`` path is unaffected by the function name.

      Returns:
          ``{"IMAGE_STEPS": <stored steps>, "status": True}``.

      Raises:
          HTTPException: 400 with an ``INCORRECT_FORMAT`` message when
              ``steps`` is negative.
      """
      if form_data.steps < 0:
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
          )

      app.state.IMAGE_STEPS = form_data.steps
      return {
          "IMAGE_STEPS": app.state.IMAGE_STEPS,
          "status": True,
      }
  @app.get("/models")
  def get_models(user=Depends(get_current_user)):
      """List the image models available for the configured engine.

      The ``"openai"`` engine returns a fixed list of DALL·E models.
      Otherwise AUTOMATIC1111's ``/sdapi/v1/sd-models`` endpoint is queried,
      and each entry is mapped to ``{"id": title, "name": model_name}``.

      Raises:
          HTTPException: 400 if the backend cannot be queried or returns an
              unexpected payload. As a side effect, image generation is
              disabled (``app.state.ENABLED = False``).
      """
      try:
          if app.state.ENGINE == "openai":
              return [
                  {"id": "dall-e-2", "name": "DALL·E 2"},
                  {"id": "dall-e-3", "name": "DALL·E 3"},
              ]

          r = requests.get(
              url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
          )
          # Report HTTP failures directly. Otherwise the error body would
          # later fail with an opaque TypeError/KeyError when mapped as a
          # model list.
          r.raise_for_status()
          return [
              {"id": model["title"], "name": model["model_name"]}
              for model in r.json()
          ]
      except Exception as e:
          app.state.ENABLED = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
  @app.get("/models/default")
  def get_default_model(user=Depends(get_admin_user)):
      """Return the currently selected default model as ``{"model": ...}``.

      For the ``"openai"`` engine this is ``app.state.MODEL``, or
      ``"dall-e-2"`` when unset. Otherwise it is the ``sd_model_checkpoint``
      option read from AUTOMATIC1111's ``/sdapi/v1/options``.

      This is a plain ``def``, like ``get_models``, so FastAPI runs the
      blocking ``requests`` call in its threadpool rather than on the event
      loop.

      Raises:
          HTTPException: 400 if the backend cannot be queried. As a side
              effect, image generation is disabled
              (``app.state.ENABLED = False``).
      """
      try:
          if app.state.ENGINE == "openai":
              return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}

          r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
          r.raise_for_status()
          options = r.json()
          return {"model": options["sd_model_checkpoint"]}
      except Exception as e:
          app.state.ENABLED = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
  class UpdateModelForm(BaseModel):
      """Request body for ``POST /models/default/update``."""

      model: str  # model id handed to set_model_handler
  def set_model_handler(model: str):
      """Make ``model`` the active model for the configured engine.

      For the ``"openai"`` engine the name is stored in ``app.state.MODEL``
      and returned. Otherwise AUTOMATIC1111's options are fetched. If the
      current ``sd_model_checkpoint`` differs, the full options dict is posted
      back with the new checkpoint. The (possibly updated) options dict is
      returned.

      Raises:
          requests.HTTPError: if AUTOMATIC1111 rejects either request.
      """
      if app.state.ENGINE == "openai":
          app.state.MODEL = model
          return app.state.MODEL

      options_url = f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
      r = requests.get(url=options_url)
      r.raise_for_status()
      options = r.json()

      if model != options["sd_model_checkpoint"]:
          options["sd_model_checkpoint"] = model
          # Check the response so that a rejected checkpoint switch is not
          # reported back to the caller as success.
          requests.post(url=options_url, json=options).raise_for_status()
      return options
  @app.post("/models/default/update")
  def update_default_model(
      form_data: UpdateModelForm,
      user=Depends(get_current_user),
  ):
      """Switch the default image model to ``form_data.model``.

      Delegates to ``set_model_handler`` and returns its result unchanged.

      NOTE(review): unlike the other configuration-mutating endpoints, this one
      depends on ``get_current_user`` rather than ``get_admin_user``. Confirm
      that letting any authenticated user change the global default is
      intended.
      """
      requested_model = form_data.model
      return set_model_handler(requested_model)
  class GenerateImageForm(BaseModel):
      """Request body for ``POST /generations``."""

      # Model/checkpoint to switch to before generating (AUTOMATIC1111 path only).
      model: Optional[str] = None
      prompt: str
      # Number of images: sent as "n" to OpenAI, "batch_size" to AUTOMATIC1111.
      n: int = 1
      # "<width>x<height>"; when omitted, app.state.IMAGE_SIZE is used.
      size: Optional[str] = None
      # Forwarded to AUTOMATIC1111 only.
      negative_prompt: Optional[str] = None
  def save_b64_image(b64_str):
      """Decode a base64-encoded image and write it to ``IMAGE_CACHE_DIR``.

      The file is named ``<uuid4>.png``. An optional data-URI header such as
      ``data:image/png;base64,`` is removed before decoding. This matters
      because ``base64.b64decode`` by default discards only non-alphabet
      characters, so the header's letters would otherwise end up mixed into
      the decoded data.

      Args:
          b64_str: Base64 image payload, with or without a data-URI header.

      Returns:
          The generated image id (the file stem) on success, or ``None`` if
          decoding or writing fails. The failure is logged with its traceback.
      """
      image_id = str(uuid.uuid4())
      file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
      try:
          if isinstance(b64_str, str) and b64_str.startswith("data:") and "," in b64_str:
              # Keep only the payload after "data:<mime>;base64,".
              b64_str = b64_str.split(",", 1)[1]
          file_path.write_bytes(base64.b64decode(b64_str))
          return image_id
      except Exception as e:
          log.exception("Error saving image: %s", e)
          return None
  def _response_error_message(r):
      """Extract a human-readable error from a backend response body, if any.

      Handles an OpenAI-style body ``{"error": {"message": ...}}`` and a body
      whose ``"error"`` value is a plain string. Returns ``None`` when the body
      is not JSON or has neither form, so the caller can fall back to the
      original exception.
      """
      try:
          body = r.json()
      except ValueError:  # non-JSON body; requests' JSON decode errors subclass ValueError
          return None
      if not isinstance(body, dict):
          return None
      error = body.get("error")
      if isinstance(error, dict):
          return error.get("message")
      if isinstance(error, str) and error:
          return error
      return None


  def _save_generated_images(b64_images, metadata):
      """Cache each base64 image with a JSON sidecar and build the client URLs.

      Each image is written through ``save_b64_image``, and ``metadata`` is
      dumped to ``<image_id>.json`` next to it.

      Raises:
          RuntimeError: if an image cannot be saved. This avoids returning a
              URL that points at ``None.png``.
      """
      images = []
      for b64_image in b64_images:
          image_id = save_b64_image(b64_image)
          if image_id is None:
              raise RuntimeError("Failed to save generated image")
          images.append({"url": f"/cache/image/generations/{image_id}.png"})
          with open(IMAGE_CACHE_DIR.joinpath(f"{image_id}.json"), "w") as f:
              json.dump(metadata, f)
      return images


  @app.post("/generations")
  def generate_image(
      form_data: GenerateImageForm,
      user=Depends(get_current_user),
  ):
      """Generate images for ``form_data.prompt`` and return their cache URLs.

      With ``ENGINE == "openai"``, the OpenAI images API is called. The model
      comes from ``app.state.MODEL`` and defaults to ``"dall-e-2"``.
      Otherwise the request goes to AUTOMATIC1111 ``/sdapi/v1/txt2img``,
      after first switching the checkpoint when ``form_data.model`` is given.

      For both engines, ``form_data.size`` overrides the configured
      ``IMAGE_SIZE``. Each image is cached as ``<id>.png``, with the request
      payload (plus AUTOMATIC1111's ``info``) saved in ``<id>.json``.

      Returns:
          A list of ``{"url": "/cache/image/generations/<id>.png"}`` dicts.

      Raises:
          HTTPException: 400 on any failure. The backend's own error message
              is used when its response body provides one.
      """
      r = None
      try:
          size = form_data.size if form_data.size else app.state.IMAGE_SIZE

          if app.state.ENGINE == "openai":
              headers = {
                  "Authorization": f"Bearer {app.state.OPENAI_API_KEY}",
                  "Content-Type": "application/json",
              }
              data = {
                  "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
                  "prompt": form_data.prompt,
                  "n": form_data.n,
                  "size": size,
                  "response_format": "b64_json",
              }
              r = requests.post(
                  url="https://api.openai.com/v1/images/generations",
                  json=data,
                  headers=headers,
              )
              r.raise_for_status()
              res = r.json()
              return _save_generated_images(
                  [image["b64_json"] for image in res["data"]], data
              )

          if form_data.model:
              set_model_handler(form_data.model)
          width, height = tuple(map(int, size.split("x")))
          data = {
              "prompt": form_data.prompt,
              "batch_size": form_data.n,
              "width": width,
              "height": height,
          }
          if app.state.IMAGE_STEPS is not None:
              data["steps"] = app.state.IMAGE_STEPS
          if form_data.negative_prompt is not None:
              data["negative_prompt"] = form_data.negative_prompt

          r = requests.post(
              url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
              json=data,
          )
          r.raise_for_status()
          res = r.json()
          # Use lazy %-formatting: the payload holds full base64 images, so
          # only build that string when debug logging is actually enabled.
          log.debug("res: %s", res)
          return _save_generated_images(res["images"], {**data, "info": res["info"]})
      except Exception as e:
          error = e
          if r is not None:
              error = _response_error_message(r) or e
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) from e