|
@@ -42,10 +42,10 @@ router = APIRouter()
|
|
|
async def get_config(request: Request, user=Depends(get_admin_user)):
|
|
|
return {
|
|
|
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
|
|
|
- "engine": request.app.state.config.ENGINE,
|
|
|
+ "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
|
|
|
"openai": {
|
|
|
- "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
|
|
|
- "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
|
|
|
+ "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
|
|
|
+ "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
|
|
|
},
|
|
|
"automatic1111": {
|
|
|
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
|
|
@@ -93,11 +93,13 @@ class ConfigForm(BaseModel):
|
|
|
async def update_config(
|
|
|
request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
|
|
|
):
|
|
|
- request.app.state.config.ENGINE = form_data.engine
|
|
|
+ request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
|
|
|
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
|
|
|
|
|
|
- request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
|
|
|
- request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
|
|
+ request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
|
|
|
+ form_data.openai.OPENAI_API_BASE_URL
|
|
|
+ )
|
|
|
+ request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
|
|
|
|
|
request.app.state.config.AUTOMATIC1111_BASE_URL = (
|
|
|
form_data.automatic1111.AUTOMATIC1111_BASE_URL
|
|
@@ -132,10 +134,10 @@ async def update_config(
|
|
|
|
|
|
return {
|
|
|
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
|
|
|
- "engine": request.app.state.config.ENGINE,
|
|
|
+ "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
|
|
|
"openai": {
|
|
|
- "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
|
|
|
- "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
|
|
|
+ "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
|
|
|
+ "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
|
|
|
},
|
|
|
"automatic1111": {
|
|
|
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
|
|
@@ -166,7 +168,7 @@ def get_automatic1111_api_auth(request: Request):
|
|
|
|
|
|
@router.get("/config/url/verify")
|
|
|
async def verify_url(request: Request, user=Depends(get_admin_user)):
|
|
|
- if request.app.state.config.ENGINE == "automatic1111":
|
|
|
+ if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
|
|
|
try:
|
|
|
r = requests.get(
|
|
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
@@ -177,7 +179,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
|
|
|
except Exception:
|
|
|
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
|
|
- elif request.app.state.config.ENGINE == "comfyui":
|
|
|
+ elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
|
try:
|
|
|
r = requests.get(
|
|
|
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
|
@@ -194,7 +196,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
|
|
|
def set_image_model(request: Request, model: str):
|
|
|
log.info(f"Setting image model to {model}")
|
|
|
request.app.state.config.MODEL = model
|
|
|
- if request.app.state.config.ENGINE in ["", "automatic1111"]:
|
|
|
+ if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
|
|
|
api_auth = get_automatic1111_api_auth()
|
|
|
r = requests.get(
|
|
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
@@ -212,17 +214,17 @@ def set_image_model(request: Request, model: str):
|
|
|
|
|
|
|
|
|
def get_image_model():
|
|
|
- if request.app.state.config.ENGINE == "openai":
|
|
|
+ if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
|
|
|
return (
|
|
|
request.app.state.config.MODEL
|
|
|
if request.app.state.config.MODEL
|
|
|
else "dall-e-2"
|
|
|
)
|
|
|
- elif request.app.state.config.ENGINE == "comfyui":
|
|
|
+ elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
|
return request.app.state.config.MODEL if request.app.state.config.MODEL else ""
|
|
|
elif (
|
|
|
- request.app.state.config.ENGINE == "automatic1111"
|
|
|
- or request.app.state.config.ENGINE == ""
|
|
|
+ request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
|
|
+ or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
|
|
|
):
|
|
|
try:
|
|
|
r = requests.get(
|
|
@@ -285,12 +287,12 @@ async def update_image_config(
|
|
|
@router.get("/models")
|
|
|
def get_models(request: Request, user=Depends(get_verified_user)):
|
|
|
try:
|
|
|
- if request.app.state.config.ENGINE == "openai":
|
|
|
+ if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
|
|
|
return [
|
|
|
{"id": "dall-e-2", "name": "DALL·E 2"},
|
|
|
{"id": "dall-e-3", "name": "DALL·E 3"},
|
|
|
]
|
|
|
- elif request.app.state.config.ENGINE == "comfyui":
|
|
|
+ elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
|
# TODO - get models from comfyui
|
|
|
r = requests.get(
|
|
|
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
|
@@ -336,8 +338,8 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
|
|
)
|
|
|
)
|
|
|
elif (
|
|
|
- request.app.state.config.ENGINE == "automatic1111"
|
|
|
- or request.app.state.config.ENGINE == ""
|
|
|
+ request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
|
|
+ or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
|
|
|
):
|
|
|
r = requests.get(
|
|
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
|
|
@@ -433,10 +435,10 @@ async def image_generations(
|
|
|
|
|
|
r = None
|
|
|
try:
|
|
|
- if request.app.state.config.ENGINE == "openai":
|
|
|
+ if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
|
|
|
headers = {}
|
|
|
headers["Authorization"] = (
|
|
|
- f"Bearer {request.app.state.config.OPENAI_API_KEY}"
|
|
|
+ f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
|
|
|
)
|
|
|
headers["Content-Type"] = "application/json"
|
|
|
|
|
@@ -465,7 +467,7 @@ async def image_generations(
|
|
|
# Use asyncio.to_thread for the requests.post call
|
|
|
r = await asyncio.to_thread(
|
|
|
requests.post,
|
|
|
- url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations",
|
|
|
+ url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
|
|
|
json=data,
|
|
|
headers=headers,
|
|
|
)
|
|
@@ -485,7 +487,7 @@ async def image_generations(
|
|
|
|
|
|
return images
|
|
|
|
|
|
- elif request.app.state.config.ENGINE == "comfyui":
|
|
|
+ elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
|
data = {
|
|
|
"prompt": form_data.prompt,
|
|
|
"width": width,
|
|
@@ -531,8 +533,8 @@ async def image_generations(
|
|
|
log.debug(f"images: {images}")
|
|
|
return images
|
|
|
elif (
|
|
|
- request.app.state.config.ENGINE == "automatic1111"
|
|
|
- or request.app.state.config.ENGINE == ""
|
|
|
+ request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
|
|
+ or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
|
|
|
):
|
|
|
if form_data.model:
|
|
|
set_image_model(form_data.model)
|