main.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from utils.misc import calculate_sha256
  21. from typing import Optional
  22. from pydantic import BaseModel
  23. from config import AUTOMATIC1111_BASE_URL
  # Application object that carries the image-generation routes of this module.
  app = FastAPI()

  # Mutable runtime settings are kept on ``app.state``.  The base URL starts
  # out as the value imported from ``config``.
  app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  # The feature starts enabled only when the base URL is not the empty string.
  app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
  app.state.IMAGE_SIZE = "512x512"

  # CORS: any origin, method and header is accepted, credentials included.
  app.add_middleware(
      CORSMiddleware,
      allow_headers=["*"],
      allow_methods=["*"],
      allow_credentials=True,
      allow_origins=["*"],
  )
  @app.get("/enabled", response_model=bool)
  async def get_enable_status(request: Request, user=Depends(get_admin_user)):
      """Return the current value of ``app.state.ENABLED``.

      ``request`` and ``user`` are not read by the body; ``user`` is resolved
      through the ``get_admin_user`` dependency.
      """
      enabled = app.state.ENABLED
      return enabled
  @app.get("/enabled/toggle", response_model=bool)
  async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
      """Flip ``app.state.ENABLED`` after probing the configured backend.

      A HEAD request is sent to ``app.state.AUTOMATIC1111_BASE_URL`` first and
      the flag is toggled only when that request completes without raising.
      The HTTP status code of the probe is not inspected.

      Returns:
          The new value of ``app.state.ENABLED``.

      Raises:
          HTTPException: status 400 when the probe raises (for example an
              unreachable host, an invalid URL or the 10 second timeout).
      """
      import asyncio
      import functools

      try:
          probe = functools.partial(
              requests.head, app.state.AUTOMATIC1111_BASE_URL, timeout=10
          )
          # ``requests`` is blocking: run it on the default executor so the
          # event loop keeps serving other requests, and bound the wait so an
          # unresponsive backend cannot hang this handler forever.
          await asyncio.get_running_loop().run_in_executor(None, probe)
          app.state.ENABLED = not app.state.ENABLED
          return app.state.ENABLED
      except Exception as e:
          raise HTTPException(
              status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)
          ) from e
  class UrlUpdateForm(BaseModel):
      """Body accepted by ``update_openai_url`` (``POST /url/update``)."""

      # New backend base URL; ``update_openai_url`` treats "" as a reset.
      url: str
  @app.get("/url")
  async def get_openai_url(user=Depends(get_admin_user)):
      """Return the backend base URL currently stored on ``app.state``.

      The value is wrapped in a dict under the ``AUTOMATIC1111_BASE_URL`` key.
      """
      current_url = app.state.AUTOMATIC1111_BASE_URL
      return {"AUTOMATIC1111_BASE_URL": current_url}
  @app.post("/url/update")
  async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
      """Replace the backend base URL stored on ``app.state``.

      An empty ``form_data.url`` restores the value imported from ``config``;
      any other value is stored with leading and trailing ``/`` characters
      removed.

      Returns:
          A dict holding the stored URL and ``"status": True``.
      """
      submitted = form_data.url
      # "" is the reset signal; everything else is stored slash-stripped.
      new_url = submitted.strip("/") if submitted != "" else AUTOMATIC1111_BASE_URL
      app.state.AUTOMATIC1111_BASE_URL = new_url

      response = {
          "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
          "status": True,
      }
      return response
  class ImageSizeUpdateForm(BaseModel):
      """Body accepted by ``update_image_size`` (``POST /size/update``)."""

      # Requested size as text; validated by ``update_image_size``.
      size: str
  @app.get("/size")
  async def get_image_size(user=Depends(get_admin_user)):
      """Return the image size string currently stored on ``app.state``.

      The value is wrapped in a dict under the ``IMAGE_SIZE`` key.
      """
      current_size = app.state.IMAGE_SIZE
      return {"IMAGE_SIZE": current_size}
  @app.post("/size/update")
  async def update_image_size(
      form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  ):
      """Validate and store a new image size on ``app.state``.

      ``form_data.size`` must consist of digits, a literal ``x`` and digits
      (e.g. ``512x512``) with nothing before or after.

      Returns:
          A dict holding the stored size and ``"status": True``.

      Raises:
          HTTPException: status 400 when the size does not have that form.
      """
      # ``fullmatch`` is used instead of ``match`` with a trailing "$" because
      # "$" also matches just before a final newline, which would let a value
      # such as "512x512\n" through and store it verbatim.
      if re.fullmatch(r"\d+x\d+", form_data.size) is None:
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
          )

      app.state.IMAGE_SIZE = form_data.size
      return {
          "IMAGE_SIZE": app.state.IMAGE_SIZE,
          "status": True,
      }
  @app.get("/models")
  def get_models(user=Depends(get_current_user)):
      """Return the decoded JSON body of the backend's ``sd-models`` endpoint.

      Raises:
          HTTPException: status 400 when the request or the JSON decoding
              raises; ``app.state.ENABLED`` is set to ``False`` first.
      """
      try:
          endpoint = f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
          return requests.get(url=endpoint).json()
      except Exception as e:
          # Any failure here switches the feature off until it is re-enabled.
          app.state.ENABLED = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  @app.get("/models/default")
  async def get_default_model(user=Depends(get_admin_user)):
      """Return the backend's ``sd_model_checkpoint`` option.

      The value is read from the backend's ``/sdapi/v1/options`` endpoint.

      Returns:
          ``{"model": <value of sd_model_checkpoint>}``.

      Raises:
          HTTPException: status 400 when the request, the JSON decoding or the
              key lookup raises; ``app.state.ENABLED`` is set to ``False``
              first.
      """
      import asyncio
      import functools

      try:
          fetch_options = functools.partial(
              requests.get,
              url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
          )
          # ``requests`` is blocking: run it on the default executor so this
          # coroutine does not stall the event loop while waiting.
          response = await asyncio.get_running_loop().run_in_executor(
              None, fetch_options
          )
          options = response.json()
          return {"model": options["sd_model_checkpoint"]}
      except Exception as e:
          app.state.ENABLED = False
          raise HTTPException(
              status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)
          ) from e
  class UpdateModelForm(BaseModel):
      """Body accepted by ``update_default_model`` (``POST /models/default/update``)."""

      # Value passed on to ``set_model_handler``.
      model: str
  def set_model_handler(model: str):
      """Make ``model`` the backend's ``sd_model_checkpoint`` if it is not already.

      The backend's options are fetched first.  When the checkpoint differs
      from ``model`` the whole options dict, with the checkpoint replaced, is
      posted back.  The response to that POST is not inspected.

      Returns:
          The options dict, with ``sd_model_checkpoint`` equal to ``model``.
      """
      current = requests.get(
          url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
      ).json()

      # Nothing to send when the requested checkpoint is already selected.
      if current["sd_model_checkpoint"] == model:
          return current

      current["sd_model_checkpoint"] = model
      requests.post(
          url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=current
      )
      return current
  @app.post("/models/default/update")
  def update_default_model(
      form_data: UpdateModelForm,
      user=Depends(get_current_user),
  ):
      """Apply ``form_data.model`` through ``set_model_handler``.

      Returns:
          Whatever ``set_model_handler`` returns for that model.
      """
      updated_options = set_model_handler(form_data.model)
      return updated_options
  class GenerateImageForm(BaseModel):
      """Body accepted by ``generate_image`` (``POST /generations``)."""

      # When set (truthy), ``generate_image`` passes it to ``set_model_handler``
      # before generating.
      model: Optional[str] = None
      # Forwarded to the backend as "prompt".
      prompt: str
      # Forwarded to the backend as "batch_size".
      n: int = 1
      # Declared here, but ``generate_image`` takes width and height from
      # ``app.state.IMAGE_SIZE`` and does not read this field.
      size: str = "512x512"
      # Forwarded as "negative_prompt" only when it is not ``None``.
      negative_prompt: Optional[str] = None
  @app.post("/generations")
  def generate_image(
      form_data: GenerateImageForm,
      user=Depends(get_current_user),
  ):
      """Request a txt2img generation from the backend and return its JSON.

      When ``form_data.model`` is set, the backend checkpoint is switched via
      ``set_model_handler`` first.  Width and height come from
      ``app.state.IMAGE_SIZE``; ``form_data.size`` is not consulted.
      ``form_data.n`` is sent as ``batch_size``.

      Returns:
          The decoded JSON body of the backend's ``/sdapi/v1/txt2img`` response.

      Raises:
          HTTPException: status 400 when any step raises.
      """
      print(form_data)

      try:
          if form_data.model:
              set_model_handler(form_data.model)

          # IMAGE_SIZE has the form "<width>x<height>".
          width, height = map(int, app.state.IMAGE_SIZE.split("x"))

          data = {
              "prompt": form_data.prompt,
              "batch_size": form_data.n,
              "width": width,
              "height": height,
          }

          # Identity test: only an absent negative prompt is skipped, so an
          # empty string is still forwarded.
          if form_data.negative_prompt is not None:
              data["negative_prompt"] = form_data.negative_prompt

          print(data)

          r = requests.post(
              url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
              json=data,
          )
          return r.json()
      except Exception as e:
          print(e)
          raise HTTPException(
              status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)
          ) from e