main.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
import logging
import os
from typing import Optional

import requests
from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
)
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel

from config import AUTOMATIC1111_BASE_URL
from constants import ERROR_MESSAGES
from utils.misc import calculate_sha256
from utils.utils import (
    get_admin_user,
    get_current_user,
)

  24. app = FastAPI()
  25. app.add_middleware(
  26. CORSMiddleware,
  27. allow_origins=["*"],
  28. allow_credentials=True,
  29. allow_methods=["*"],
  30. allow_headers=["*"],
  31. )
  32. app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  33. app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
  34. @app.get("/enabled", response_model=bool)
  35. async def get_enable_status(request: Request, user=Depends(get_admin_user)):
  36. return app.state.ENABLED
  37. @app.get("/enabled/toggle", response_model=bool)
  38. async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
  39. app.state.ENABLED = not app.state.ENABLED
  40. return app.state.ENABLED
  41. class UrlUpdateForm(BaseModel):
  42. url: str
  43. @app.get("/url")
  44. async def get_openai_url(user=Depends(get_admin_user)):
  45. return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
  46. @app.post("/url/update")
  47. async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
  48. try:
  49. r = requests.head(form_data.url)
  50. if r.ok:
  51. app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
  52. return {
  53. "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
  54. "status": True,
  55. }
  56. except Exception as e:
  57. raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
  58. @app.get("/models")
  59. def get_models(user=Depends(get_current_user)):
  60. try:
  61. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
  62. models = r.json()
  63. return models
  64. except Exception as e:
  65. raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
  66. @app.get("/models/default")
  67. async def get_default_model(user=Depends(get_admin_user)):
  68. try:
  69. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  70. options = r.json()
  71. return {"model": options["sd_model_checkpoint"]}
  72. except Exception as e:
  73. raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
  74. class UpdateModelForm(BaseModel):
  75. model: str
  76. def set_model_handler(model: str):
  77. r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
  78. options = r.json()
  79. if model != options["sd_model_checkpoint"]:
  80. options["sd_model_checkpoint"] = model
  81. r = requests.post(
  82. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
  83. )
  84. return options
  85. @app.post("/models/default/update")
  86. def update_default_model(
  87. form_data: UpdateModelForm,
  88. user=Depends(get_current_user),
  89. ):
  90. return set_model_handler(form_data.model)
  91. class GenerateImageForm(BaseModel):
  92. model: Optional[str] = None
  93. prompt: str
  94. n: int = 1
  95. size: str = "512x512"
  96. negative_prompt: Optional[str] = None
  97. @app.post("/generations")
  98. def generate_image(
  99. form_data: GenerateImageForm,
  100. user=Depends(get_current_user),
  101. ):
  102. print(form_data)
  103. try:
  104. if form_data.model:
  105. set_model_handler(form_data.model)
  106. width, height = tuple(map(int, form_data.size.split("x")))
  107. data = {
  108. "prompt": form_data.prompt,
  109. "batch_size": form_data.n,
  110. "width": width,
  111. "height": height,
  112. }
  113. if form_data.negative_prompt != None:
  114. data["negative_prompt"] = form_data.negative_prompt
  115. print(data)
  116. r = requests.post(
  117. url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  118. json=data,
  119. )
  120. return r.json()
  121. except Exception as e:
  122. print(e)
  123. raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))