123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- import os
- import requests
- from fastapi import (
- FastAPI,
- Request,
- Depends,
- HTTPException,
- status,
- UploadFile,
- File,
- Form,
- )
- from fastapi.middleware.cors import CORSMiddleware
- from faster_whisper import WhisperModel
- from constants import ERROR_MESSAGES
- from utils.utils import (
- get_current_user,
- get_admin_user,
- )
- from utils.misc import calculate_sha256
- from typing import Optional
- from pydantic import BaseModel
- from config import AUTOMATIC1111_BASE_URL
# FastAPI application whose routes talk to an AUTOMATIC1111 backend.
app = FastAPI()

# CORS is left fully open: any origin, method and header, credentials allowed.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Runtime-mutable settings live on app.state so the routes can change them.
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
# The feature starts enabled only when a backend URL has been configured.
app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
@app.get("/enabled", response_model=bool)
async def get_enable_status(request: Request, user=Depends(get_admin_user)):
    """Report the current value of the ``ENABLED`` flag kept on ``app.state``.

    ``request`` is accepted but not used. ``user`` is resolved through the
    ``get_admin_user`` dependency and is likewise unused in the body.
    """
    is_enabled = app.state.ENABLED
    return is_enabled
@app.get("/enabled/toggle", response_model=bool)
async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
    """Flip the ``ENABLED`` flag on ``app.state`` and return its new value.

    ``request`` is accepted but not used. ``user`` is resolved through the
    ``get_admin_user`` dependency.
    """
    flipped = not app.state.ENABLED
    app.state.ENABLED = flipped
    return app.state.ENABLED
class UrlUpdateForm(BaseModel):
    """Payload carrying a replacement backend base URL."""

    # The new base URL, as supplied by the caller.
    url: str
@app.get("/url")
async def get_openai_url(user=Depends(get_admin_user)):
    """Return the AUTOMATIC1111 base URL currently held on ``app.state``.

    The value is wrapped in a dict under the ``AUTOMATIC1111_BASE_URL`` key.
    """
    current_url = app.state.AUTOMATIC1111_BASE_URL
    payload = {"AUTOMATIC1111_BASE_URL": current_url}
    return payload
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    """Store a new AUTOMATIC1111 base URL after probing it with a HEAD request.

    Args:
        form_data: Body whose ``url`` is the candidate base URL.
        user: Resolved through the ``get_admin_user`` dependency.

    Returns:
        A dict with the stored ``AUTOMATIC1111_BASE_URL`` (slashes stripped
        from both ends) and ``"status": True``.

    Raises:
        HTTPException: 400 when the probe cannot be sent at all (invalid URL,
            connection failure, ...); the probe's own status code when the
            target answers with an error status.
    """
    # NOTE(review): `requests` is blocking, so this call holds the event loop
    # of this `async def` handler until the probe completes.
    try:
        probe = requests.head(form_data.url)
    except requests.exceptions.RequestException as e:
        # No response object exists here, so there is no upstream status to
        # forward; the previous code read `r.status_code` and crashed with
        # UnboundLocalError instead of reporting the failure.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e

    if not probe.ok:
        # Reject explicitly rather than falling through with no response body
        # or silently keeping the old URL.
        raise HTTPException(
            status_code=probe.status_code,
            detail=ERROR_MESSAGES.DEFAULT(
                f"URL probe returned HTTP {probe.status_code}"
            ),
        )

    app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
    return {
        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
        "status": True,
    }
@app.get("/models")
def get_models(user=Depends(get_current_user)):
    """Fetch the model list from the backend's ``/sdapi/v1/sd-models`` route.

    Args:
        user: Resolved through the ``get_current_user`` dependency.

    Returns:
        The decoded JSON body of the backend response, unchanged.

    Raises:
        HTTPException: With the backend's status code when it answered with an
            error status; otherwise 500 (request could not be made, or the
            body was not valid JSON).
    """
    r = None  # stays None if the request itself raises
    try:
        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
        return r.json()
    except Exception as e:
        # Only forward the upstream status when it really is an error status:
        # `r` may never have been bound, and a 2xx status must not be reused
        # for an error response.
        upstream_error = r is not None and r.status_code >= 400
        raise HTTPException(
            status_code=(
                r.status_code
                if upstream_error
                else status.HTTP_500_INTERNAL_SERVER_ERROR
            ),
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
    """Return the checkpoint named in the backend's ``/sdapi/v1/options``.

    Args:
        user: Resolved through the ``get_admin_user`` dependency.

    Returns:
        ``{"model": <value of options["sd_model_checkpoint"]>}``.

    Raises:
        HTTPException: With the backend's status code when it answered with an
            error status; otherwise 500 (request could not be made, body was
            not valid JSON, or ``sd_model_checkpoint`` was missing).
    """
    r = None  # stays None if the request itself raises
    try:
        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
        options = r.json()
        return {"model": options["sd_model_checkpoint"]}
    except Exception as e:
        # Only forward the upstream status when it really is an error status:
        # `r` may never have been bound, and a 2xx status must not be reused
        # for an error response (e.g. a KeyError on an otherwise fine reply).
        upstream_error = r is not None and r.status_code >= 400
        raise HTTPException(
            status_code=(
                r.status_code
                if upstream_error
                else status.HTTP_500_INTERNAL_SERVER_ERROR
            ),
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
class UpdateModelForm(BaseModel):
    """Payload naming the model that should become the default."""

    # Model identifier, compared against the backend's `sd_model_checkpoint`.
    model: str
def set_model_handler(model: str):
    """Make ``model`` the backend's active checkpoint unless it already is.

    Reads the options from ``/sdapi/v1/options``. When the stored
    ``sd_model_checkpoint`` differs from ``model``, the entry is overwritten
    and the whole options dict is posted back to the same route.

    Returns:
        The options dict, carrying the new checkpoint when it was changed.
    """
    settings = requests.get(
        url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
    ).json()

    # Nothing to do when the requested checkpoint is already selected.
    if settings["sd_model_checkpoint"] == model:
        return settings

    settings["sd_model_checkpoint"] = model
    # The response of the update call is not inspected.
    requests.post(
        url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=settings
    )
    return settings
@app.post("/models/default/update")
def update_default_model(
    form_data: UpdateModelForm,
    user=Depends(get_current_user),
):
    """Switch the backend to ``form_data.model`` via ``set_model_handler``.

    Returns whatever ``set_model_handler`` returns (the options dict).
    ``user`` is resolved through the ``get_current_user`` dependency.
    """
    requested_model = form_data.model
    updated_options = set_model_handler(requested_model)
    return updated_options
class GenerateImageForm(BaseModel):
    """Payload describing one image-generation request."""

    # Optional model to activate before generating; omitted means "leave as is".
    model: Optional[str] = None
    # Text prompt forwarded to the backend.
    prompt: str
    # Forwarded as the backend's `batch_size`.
    n: int = 1
    # "<width>x<height>", split on "x" and converted to integers.
    size: str = "512x512"
    # Forwarded only when it is not None.
    negative_prompt: Optional[str] = None
@app.post("/generations")
def generate_image(
    form_data: GenerateImageForm,
    user=Depends(get_current_user),
):
    """Generate images through the backend's ``/sdapi/v1/txt2img`` route.

    Args:
        form_data: Prompt, batch size (``n``), ``"<width>x<height>"`` size and
            optional model / negative prompt.
        user: Resolved through the ``get_current_user`` dependency.

    Returns:
        The decoded JSON body of the backend response, unchanged.

    Raises:
        HTTPException: 400 when ``size`` is not two integers joined by ``x``.
            Otherwise the backend's status code when it answered the txt2img
            call with an error status, or 500 for any other failure (model
            switch failed, request could not be made, body was not JSON).
    """
    print(form_data)

    # Validate the size before touching the backend: a malformed value is a
    # client error, and it should not cost a model switch first.
    try:
        width, height = map(int, form_data.size.split("x"))
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e

    r = None  # stays None if anything fails before the txt2img call returns
    try:
        if form_data.model:
            set_model_handler(form_data.model)

        data = {
            "prompt": form_data.prompt,
            "batch_size": form_data.n,
            "width": width,
            "height": height,
        }
        if form_data.negative_prompt is not None:
            data["negative_prompt"] = form_data.negative_prompt
        print(data)

        r = requests.post(
            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
            json=data,
        )
        return r.json()
    except Exception as e:
        print(e)
        # Only forward the upstream status when it really is an error status:
        # `r` may never have been bound, and a 2xx status must not be reused
        # for an error response.
        upstream_error = r is not None and r.status_code >= 400
        raise HTTPException(
            status_code=(
                r.status_code
                if upstream_error
                else status.HTTP_500_INTERNAL_SERVER_ERROR
            ),
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
|