main.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from utils.misc import calculate_sha256
  21. from typing import Optional
  22. from pydantic import BaseModel
  23. from pathlib import Path
  24. import uuid
  25. import base64
  26. import json
  27. from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
# Generated images, plus a JSON sidecar describing each request, are written here.
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this sub-app is only reachable behind the main app's auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mutable runtime settings. They live only in process memory (app.state) and
# are overwritten by the endpoints below.
app.state.ENGINE = ""  # "openai" selects OpenAI; any other value uses AUTOMATIC1111
app.state.ENABLED = False
app.state.OPENAI_API_KEY = ""
app.state.MODEL = ""  # OpenAI model id; empty falls back to "dall-e-2"
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.IMAGE_SIZE = "512x512"  # "<width>x<height>", used by the AUTOMATIC1111 engine
app.state.IMAGE_STEPS = 50  # sampling steps sent to AUTOMATIC1111 txt2img
  45. @app.get("/config")
  46. async def get_config(request: Request, user=Depends(get_admin_user)):
  47. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  48. class ConfigUpdateForm(BaseModel):
  49. engine: str
  50. enabled: bool
  51. @app.post("/config/update")
  52. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  53. app.state.ENGINE = form_data.engine
  54. app.state.ENABLED = form_data.enabled
  55. return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
  56. class UrlUpdateForm(BaseModel):
  57. url: str
  58. @app.get("/url")
  59. async def get_automatic1111_url(user=Depends(get_admin_user)):
  60. return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
  61. @app.post("/url/update")
  62. async def update_automatic1111_url(
  63. form_data: UrlUpdateForm, user=Depends(get_admin_user)
  64. ):
  65. if form_data.url == "":
  66. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  67. else:
  68. url = form_data.url.strip("/")
  69. try:
  70. r = requests.head(url)
  71. app.state.AUTOMATIC1111_BASE_URL = url
  72. except Exception as e:
  73. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  74. return {
  75. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  76. "status": True,
  77. }
  78. class OpenAIKeyUpdateForm(BaseModel):
  79. key: str
  80. @app.get("/key")
  81. async def get_openai_key(user=Depends(get_admin_user)):
  82. return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
  83. @app.post("/key/update")
  84. async def update_openai_key(
  85. form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
  86. ):
  87. if form_data.key == "":
  88. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  89. app.state.OPENAI_API_KEY = form_data.key
  90. return {
  91. "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
  92. "status": True,
  93. }
  94. class ImageSizeUpdateForm(BaseModel):
  95. size: str
  96. @app.get("/size")
  97. async def get_image_size(user=Depends(get_admin_user)):
  98. return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
  99. @app.post("/size/update")
  100. async def update_image_size(
  101. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  102. ):
  103. pattern = r"^\d+x\d+$" # Regular expression pattern
  104. if re.match(pattern, form_data.size):
  105. app.state.IMAGE_SIZE = form_data.size
  106. return {
  107. "IMAGE_SIZE": app.state.IMAGE_SIZE,
  108. "status": True,
  109. }
  110. else:
  111. raise HTTPException(
  112. status_code=400,
  113. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  114. )
  115. class ImageStepsUpdateForm(BaseModel):
  116. steps: int
  117. @app.get("/steps")
  118. async def get_image_size(user=Depends(get_admin_user)):
  119. return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
  120. @app.post("/steps/update")
  121. async def update_image_size(
  122. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  123. ):
  124. if form_data.steps >= 0:
  125. app.state.IMAGE_STEPS = form_data.steps
  126. return {
  127. "IMAGE_STEPS": app.state.IMAGE_STEPS,
  128. "status": True,
  129. }
  130. else:
  131. raise HTTPException(
  132. status_code=400,
  133. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  134. )
  135. @app.get("/models")
  136. def get_models(user=Depends(get_current_user)):
  137. try:
  138. if app.state.ENGINE == "openai":
  139. return [
  140. {"id": "dall-e-2", "name": "DALL·E 2"},
  141. {"id": "dall-e-3", "name": "DALL·E 3"},
  142. ]
  143. else:
  144. r = requests.get(
  145. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
  146. )
  147. models = r.json()
  148. return list(
  149. map(
  150. lambda model: {"id": model["title"], "name": model["model_name"]},
  151. models,
  152. )
  153. )
  154. except Exception as e:
  155. app.state.ENABLED = False
  156. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  157. @app.get("/models/default")
  158. async def get_default_model(user=Depends(get_admin_user)):
  159. try:
  160. if app.state.ENGINE == "openai":
  161. return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
  162. else:
  163. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  164. options = r.json()
  165. return {"model": options["sd_model_checkpoint"]}
  166. except Exception as e:
  167. app.state.ENABLED = False
  168. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  169. class UpdateModelForm(BaseModel):
  170. model: str
  171. def set_model_handler(model: str):
  172. if app.state.ENGINE == "openai":
  173. app.state.MODEL = model
  174. return app.state.MODEL
  175. else:
  176. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  177. options = r.json()
  178. if model != options["sd_model_checkpoint"]:
  179. options["sd_model_checkpoint"] = model
  180. r = requests.post(
  181. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
  182. )
  183. return options
  184. @app.post("/models/default/update")
  185. def update_default_model(
  186. form_data: UpdateModelForm,
  187. user=Depends(get_current_user),
  188. ):
  189. return set_model_handler(form_data.model)
class GenerateImageForm(BaseModel):
    """Request body for ``POST /generations``."""

    model: Optional[str] = None  # AUTOMATIC1111 only: checkpoint to switch to first
    prompt: str
    n: int = 1  # image count (OpenAI "n" / AUTOMATIC1111 "batch_size")
    size: str = "512x512"  # OpenAI only; AUTOMATIC1111 uses app.state.IMAGE_SIZE
    negative_prompt: Optional[str] = None  # AUTOMATIC1111 only
  196. def save_b64_image(b64_str):
  197. image_id = str(uuid.uuid4())
  198. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
  199. try:
  200. # Split the base64 string to get the actual image data
  201. img_data = base64.b64decode(b64_str)
  202. # Write the image data to a file
  203. with open(file_path, "wb") as f:
  204. f.write(img_data)
  205. return image_id
  206. except Exception as e:
  207. print(f"Error saving image: {e}")
  208. return None
  209. @app.post("/generations")
  210. def generate_image(
  211. form_data: GenerateImageForm,
  212. user=Depends(get_current_user),
  213. ):
  214. print(form_data)
  215. try:
  216. if app.state.ENGINE == "openai":
  217. headers = {}
  218. headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
  219. headers["Content-Type"] = "application/json"
  220. data = {
  221. "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
  222. "prompt": form_data.prompt,
  223. "n": form_data.n,
  224. "size": form_data.size,
  225. "response_format": "b64_json",
  226. }
  227. r = requests.post(
  228. url=f"https://api.openai.com/v1/images/generations",
  229. json=data,
  230. headers=headers,
  231. )
  232. r.raise_for_status()
  233. res = r.json()
  234. images = []
  235. for image in res["data"]:
  236. image_id = save_b64_image(image["b64_json"])
  237. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  238. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  239. with open(file_body_path, "w") as f:
  240. json.dump(data, f)
  241. return images
  242. else:
  243. if form_data.model:
  244. set_model_handler(form_data.model)
  245. width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
  246. data = {
  247. "prompt": form_data.prompt,
  248. "batch_size": form_data.n,
  249. "width": width,
  250. "height": height,
  251. }
  252. if app.state.IMAGE_STEPS != None:
  253. data["steps"] = app.state.IMAGE_STEPS
  254. if form_data.negative_prompt != None:
  255. data["negative_prompt"] = form_data.negative_prompt
  256. r = requests.post(
  257. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  258. json=data,
  259. )
  260. res = r.json()
  261. print(res)
  262. images = []
  263. for image in res["images"]:
  264. image_id = save_b64_image(image)
  265. images.append({"url": f"/cache/image/generations/{image_id}.png"})
  266. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
  267. with open(file_body_path, "w") as f:
  268. json.dump({**data, "info": res["info"]}, f)
  269. return images
  270. except Exception as e:
  271. print(e)
  272. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))