main.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import os
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. get_current_user,
  18. get_admin_user,
  19. )
  20. from utils.misc import calculate_sha256
  21. from typing import Optional
  22. from pydantic import BaseModel
  23. from config import AUTOMATIC1111_BASE_URL
  24. app = FastAPI()
  25. app.add_middleware(
  26. CORSMiddleware,
  27. allow_origins=["*"],
  28. allow_credentials=True,
  29. allow_methods=["*"],
  30. allow_headers=["*"],
  31. )
  32. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  33. app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
  34. @app.get("/enabled", response_model=bool)
  35. async def get_enable_status(request: Request, user=Depends(get_admin_user)):
  36. return app.state.ENABLED
  37. @app.get("/enabled/toggle", response_model=bool)
  38. async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
  39. try:
  40. r = requests.head(app.state.AUTOMATIC1111_BASE_URL)
  41. app.state.ENABLED = not app.state.ENABLED
  42. return app.state.ENABLED
  43. except Exception as e:
  44. raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
  45. class UrlUpdateForm(BaseModel):
  46. url: str
  47. @app.get("/url")
  48. async def get_openai_url(user=Depends(get_admin_user)):
  49. return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
  50. @app.post("/url/update")
  51. async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
  52. if form_data.url == "":
  53. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  54. else:
  55. app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
  56. return {
  57. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  58. "status": True,
  59. }
  60. @app.get("/models")
  61. def get_models(user=Depends(get_current_user)):
  62. try:
  63. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
  64. models = r.json()
  65. return models
  66. except Exception as e:
  67. raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
  68. @app.get("/models/default")
  69. async def get_default_model(user=Depends(get_admin_user)):
  70. try:
  71. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  72. options = r.json()
  73. return {"model": options["sd_model_checkpoint"]}
  74. except Exception as e:
  75. raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
  76. class UpdateModelForm(BaseModel):
  77. model: str
  78. def set_model_handler(model: str):
  79. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  80. options = r.json()
  81. if model != options["sd_model_checkpoint"]:
  82. options["sd_model_checkpoint"] = model
  83. r = requests.post(
  84. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
  85. )
  86. return options
  87. @app.post("/models/default/update")
  88. def update_default_model(
  89. form_data: UpdateModelForm,
  90. user=Depends(get_current_user),
  91. ):
  92. return set_model_handler(form_data.model)
  93. class GenerateImageForm(BaseModel):
  94. model: Optional[str] = None
  95. prompt: str
  96. n: int = 1
  97. size: str = "512x512"
  98. negative_prompt: Optional[str] = None
  99. @app.post("/generations")
  100. def generate_image(
  101. form_data: GenerateImageForm,
  102. user=Depends(get_current_user),
  103. ):
  104. print(form_data)
  105. try:
  106. if form_data.model:
  107. set_model_handler(form_data.model)
  108. width, height = tuple(map(int, form_data.size.split("x")))
  109. data = {
  110. "prompt": form_data.prompt,
  111. "batch_size": form_data.n,
  112. "width": width,
  113. "height": height,
  114. }
  115. if form_data.negative_prompt != None:
  116. data["negative_prompt"] = form_data.negative_prompt
  117. print(data)
  118. r = requests.post(
  119. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  120. json=data,
  121. )
  122. return r.json()
  123. except Exception as e:
  124. print(e)
  125. raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))