|
@@ -9,21 +9,24 @@ from pathlib import Path
|
|
from typing import Optional
|
|
from typing import Optional
|
|
|
|
|
|
import requests
|
|
import requests
|
|
-from open_webui.utils.images.comfyui import (
|
|
|
|
- ComfyUIGenerateImageForm,
|
|
|
|
- ComfyUIWorkflow,
|
|
|
|
- comfyui_generate_image,
|
|
|
|
-)
|
|
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
|
|
|
|
+from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
+from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
from open_webui.config import CACHE_DIR
|
|
from open_webui.config import CACHE_DIR
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
|
|
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
|
|
|
|
|
|
-from fastapi import Depends, FastAPI, HTTPException, Request
|
|
|
|
-from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
-from pydantic import BaseModel
|
|
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
|
|
+from open_webui.utils.images.comfyui import (
|
|
|
|
+ ComfyUIGenerateImageForm,
|
|
|
|
+ ComfyUIWorkflow,
|
|
|
|
+ comfyui_generate_image,
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
log = logging.getLogger(__name__)
|
|
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
|
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
|
@@ -31,33 +34,30 @@ log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
|
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
|
|
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
|
|
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
-app = FastAPI(
|
|
|
|
- docs_url="/docs" if ENV == "dev" else None,
|
|
|
|
- openapi_url="/openapi.json" if ENV == "dev" else None,
|
|
|
|
- redoc_url=None,
|
|
|
|
-)
|
|
|
|
|
|
|
|
|
|
+router = APIRouter()
|
|
|
|
|
|
-@app.get("/config")
|
|
|
|
|
|
+
|
|
|
|
+@router.get("/config")
|
|
async def get_config(request: Request, user=Depends(get_admin_user)):
|
|
async def get_config(request: Request, user=Depends(get_admin_user)):
|
|
return {
|
|
return {
|
|
- "enabled": app.state.config.ENABLED,
|
|
|
|
- "engine": app.state.config.ENGINE,
|
|
|
|
|
|
+ "enabled": request.app.state.config.ENABLED,
|
|
|
|
+ "engine": request.app.state.config.ENGINE,
|
|
"openai": {
|
|
"openai": {
|
|
- "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
|
|
|
|
- "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
|
|
|
|
|
|
+ "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
|
|
|
|
+ "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
|
|
},
|
|
},
|
|
"automatic1111": {
|
|
"automatic1111": {
|
|
- "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
|
|
|
|
- "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
|
|
|
|
- "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
|
|
|
|
- "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
|
|
|
|
- "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
|
|
|
|
|
|
+ "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
|
|
|
|
+ "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
|
|
|
|
+ "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
|
|
|
|
+ "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
|
|
|
|
+ "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
|
|
},
|
|
},
|
|
"comfyui": {
|
|
"comfyui": {
|
|
- "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
|
|
|
|
- "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
|
|
|
|
- "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
|
|
|
|
+ "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
|
|
|
|
+ "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
|
|
|
+ "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
|
|
@@ -89,133 +89,150 @@ class ConfigForm(BaseModel):
|
|
comfyui: ComfyUIConfigForm
|
|
comfyui: ComfyUIConfigForm
|
|
|
|
|
|
|
|
|
|
-@app.post("/config/update")
|
|
|
|
-async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
|
|
|
|
- app.state.config.ENGINE = form_data.engine
|
|
|
|
- app.state.config.ENABLED = form_data.enabled
|
|
|
|
|
|
+@router.post("/config/update")
|
|
|
|
+async def update_config(
|
|
|
|
+ request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
|
|
|
|
+):
|
|
|
|
+ request.app.state.config.ENGINE = form_data.engine
|
|
|
|
+ request.app.state.config.ENABLED = form_data.enabled
|
|
|
|
|
|
- app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
|
|
|
|
- app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
|
|
|
|
|
+ request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
|
|
|
|
+ request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
|
|
|
|
|
- app.state.config.AUTOMATIC1111_BASE_URL = (
|
|
|
|
|
|
+ request.app.state.config.AUTOMATIC1111_BASE_URL = (
|
|
form_data.automatic1111.AUTOMATIC1111_BASE_URL
|
|
form_data.automatic1111.AUTOMATIC1111_BASE_URL
|
|
)
|
|
)
|
|
- app.state.config.AUTOMATIC1111_API_AUTH = (
|
|
|
|
|
|
+ request.app.state.config.AUTOMATIC1111_API_AUTH = (
|
|
form_data.automatic1111.AUTOMATIC1111_API_AUTH
|
|
form_data.automatic1111.AUTOMATIC1111_API_AUTH
|
|
)
|
|
)
|
|
|
|
|
|
- app.state.config.AUTOMATIC1111_CFG_SCALE = (
|
|
|
|
|
|
+ request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
|
|
float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
|
|
float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
|
|
if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
|
|
if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
|
|
else None
|
|
else None
|
|
)
|
|
)
|
|
- app.state.config.AUTOMATIC1111_SAMPLER = (
|
|
|
|
|
|
+ request.app.state.config.AUTOMATIC1111_SAMPLER = (
|
|
form_data.automatic1111.AUTOMATIC1111_SAMPLER
|
|
form_data.automatic1111.AUTOMATIC1111_SAMPLER
|
|
if form_data.automatic1111.AUTOMATIC1111_SAMPLER
|
|
if form_data.automatic1111.AUTOMATIC1111_SAMPLER
|
|
else None
|
|
else None
|
|
)
|
|
)
|
|
- app.state.config.AUTOMATIC1111_SCHEDULER = (
|
|
|
|
|
|
+ request.app.state.config.AUTOMATIC1111_SCHEDULER = (
|
|
form_data.automatic1111.AUTOMATIC1111_SCHEDULER
|
|
form_data.automatic1111.AUTOMATIC1111_SCHEDULER
|
|
if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
|
|
if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
|
|
else None
|
|
else None
|
|
)
|
|
)
|
|
|
|
|
|
- app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
|
|
|
|
- app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
|
|
|
|
- app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
|
|
|
|
|
|
+ request.app.state.config.COMFYUI_BASE_URL = (
|
|
|
|
+ form_data.comfyui.COMFYUI_BASE_URL.strip("/")
|
|
|
|
+ )
|
|
|
|
+ request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
|
|
|
|
+ request.app.state.config.COMFYUI_WORKFLOW_NODES = (
|
|
|
|
+ form_data.comfyui.COMFYUI_WORKFLOW_NODES
|
|
|
|
+ )
|
|
|
|
|
|
return {
|
|
return {
|
|
- "enabled": app.state.config.ENABLED,
|
|
|
|
- "engine": app.state.config.ENGINE,
|
|
|
|
|
|
+ "enabled": request.app.state.config.ENABLED,
|
|
|
|
+ "engine": request.app.state.config.ENGINE,
|
|
"openai": {
|
|
"openai": {
|
|
- "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
|
|
|
|
- "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
|
|
|
|
|
|
+ "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
|
|
|
|
+ "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
|
|
},
|
|
},
|
|
"automatic1111": {
|
|
"automatic1111": {
|
|
- "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
|
|
|
|
- "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
|
|
|
|
- "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
|
|
|
|
- "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
|
|
|
|
- "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
|
|
|
|
|
|
+ "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
|
|
|
|
+ "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
|
|
|
|
+ "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
|
|
|
|
+ "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
|
|
|
|
+ "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
|
|
},
|
|
},
|
|
"comfyui": {
|
|
"comfyui": {
|
|
- "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
|
|
|
|
- "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
|
|
|
|
- "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
|
|
|
|
+ "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
|
|
|
|
+ "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
|
|
|
+ "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
-def get_automatic1111_api_auth():
|
|
|
|
- if app.state.config.AUTOMATIC1111_API_AUTH is None:
|
|
|
|
|
|
+def get_automatic1111_api_auth(request: Request):
|
|
|
|
+ if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
|
|
return ""
|
|
return ""
|
|
else:
|
|
else:
|
|
- auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
|
|
|
|
|
|
+ auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
|
|
|
|
+ "utf-8"
|
|
|
|
+ )
|
|
auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
|
|
auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
|
|
auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
|
|
auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
|
|
return f"Basic {auth1111_base64_encoded_string}"
|
|
return f"Basic {auth1111_base64_encoded_string}"
|
|
|
|
|
|
|
|
|
|
-@app.get("/config/url/verify")
|
|
|
|
-async def verify_url(user=Depends(get_admin_user)):
|
|
|
|
- if app.state.config.ENGINE == "automatic1111":
|
|
|
|
|
|
+@router.get("/config/url/verify")
|
|
|
|
+async def verify_url(request: Request, user=Depends(get_admin_user)):
|
|
|
|
+ if request.app.state.config.ENGINE == "automatic1111":
|
|
try:
|
|
try:
|
|
r = requests.get(
|
|
r = requests.get(
|
|
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
|
|
- headers={"authorization": get_automatic1111_api_auth()},
|
|
|
|
|
|
+ url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
|
|
+ headers={"authorization": get_automatic1111_api_auth(request)},
|
|
)
|
|
)
|
|
r.raise_for_status()
|
|
r.raise_for_status()
|
|
return True
|
|
return True
|
|
except Exception:
|
|
except Exception:
|
|
- app.state.config.ENABLED = False
|
|
|
|
|
|
+ request.app.state.config.ENABLED = False
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
|
- elif app.state.config.ENGINE == "comfyui":
|
|
|
|
|
|
+ elif request.app.state.config.ENGINE == "comfyui":
|
|
try:
|
|
try:
|
|
- r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
|
|
|
|
|
|
+ r = requests.get(
|
|
|
|
+ url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
|
|
|
+ )
|
|
r.raise_for_status()
|
|
r.raise_for_status()
|
|
return True
|
|
return True
|
|
except Exception:
|
|
except Exception:
|
|
- app.state.config.ENABLED = False
|
|
|
|
|
|
+ request.app.state.config.ENABLED = False
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
|
else:
|
|
else:
|
|
return True
|
|
return True
|
|
|
|
|
|
|
|
|
|
-def set_image_model(model: str):
|
|
|
|
|
|
+def set_image_model(request: Request, model: str):
|
|
log.info(f"Setting image model to {model}")
|
|
log.info(f"Setting image model to {model}")
|
|
- app.state.config.MODEL = model
|
|
|
|
- if app.state.config.ENGINE in ["", "automatic1111"]:
|
|
|
|
|
|
+ request.app.state.config.MODEL = model
|
|
|
|
+ if request.app.state.config.ENGINE in ["", "automatic1111"]:
|
|
api_auth = get_automatic1111_api_auth()
|
|
api_auth = get_automatic1111_api_auth()
|
|
r = requests.get(
|
|
r = requests.get(
|
|
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
|
|
|
|
+ url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
headers={"authorization": api_auth},
|
|
headers={"authorization": api_auth},
|
|
)
|
|
)
|
|
options = r.json()
|
|
options = r.json()
|
|
if model != options["sd_model_checkpoint"]:
|
|
if model != options["sd_model_checkpoint"]:
|
|
options["sd_model_checkpoint"] = model
|
|
options["sd_model_checkpoint"] = model
|
|
r = requests.post(
|
|
r = requests.post(
|
|
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
|
|
|
|
+ url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
json=options,
|
|
json=options,
|
|
headers={"authorization": api_auth},
|
|
headers={"authorization": api_auth},
|
|
)
|
|
)
|
|
- return app.state.config.MODEL
|
|
|
|
|
|
+ return request.app.state.config.MODEL
|
|
|
|
|
|
|
|
|
|
def get_image_model():
|
|
def get_image_model():
|
|
- if app.state.config.ENGINE == "openai":
|
|
|
|
- return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
|
|
|
|
- elif app.state.config.ENGINE == "comfyui":
|
|
|
|
- return app.state.config.MODEL if app.state.config.MODEL else ""
|
|
|
|
- elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
|
|
|
|
|
|
+ if request.app.state.config.ENGINE == "openai":
|
|
|
|
+ return (
|
|
|
|
+ request.app.state.config.MODEL
|
|
|
|
+ if request.app.state.config.MODEL
|
|
|
|
+ else "dall-e-2"
|
|
|
|
+ )
|
|
|
|
+ elif request.app.state.config.ENGINE == "comfyui":
|
|
|
|
+ return request.app.state.config.MODEL if request.app.state.config.MODEL else ""
|
|
|
|
+ elif (
|
|
|
|
+ request.app.state.config.ENGINE == "automatic1111"
|
|
|
|
+ or request.app.state.config.ENGINE == ""
|
|
|
|
+ ):
|
|
try:
|
|
try:
|
|
r = requests.get(
|
|
r = requests.get(
|
|
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
|
|
|
|
+ url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
headers={"authorization": get_automatic1111_api_auth()},
|
|
headers={"authorization": get_automatic1111_api_auth()},
|
|
)
|
|
)
|
|
options = r.json()
|
|
options = r.json()
|
|
return options["sd_model_checkpoint"]
|
|
return options["sd_model_checkpoint"]
|
|
except Exception as e:
|
|
except Exception as e:
|
|
- app.state.config.ENABLED = False
|
|
|
|
|
|
+ request.app.state.config.ENABLED = False
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
|
|
|
|
|
|
|
|
|
@@ -225,23 +242,25 @@ class ImageConfigForm(BaseModel):
|
|
IMAGE_STEPS: int
|
|
IMAGE_STEPS: int
|
|
|
|
|
|
|
|
|
|
-@app.get("/image/config")
|
|
|
|
-async def get_image_config(user=Depends(get_admin_user)):
|
|
|
|
|
|
+@router.get("/image/config")
|
|
|
|
+async def get_image_config(request: Request, user=Depends(get_admin_user)):
|
|
return {
|
|
return {
|
|
- "MODEL": app.state.config.MODEL,
|
|
|
|
- "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
|
|
|
|
- "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
|
|
|
|
|
|
+ "MODEL": request.app.state.config.MODEL,
|
|
|
|
+ "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
|
|
|
|
+ "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
-@app.post("/image/config/update")
|
|
|
|
-async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
|
|
|
|
|
|
+@router.post("/image/config/update")
|
|
|
|
+async def update_image_config(
|
|
|
|
+ request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
|
|
|
|
+):
|
|
|
|
|
|
- set_image_model(form_data.MODEL)
|
|
|
|
|
|
+ set_image_model(request, form_data.MODEL)
|
|
|
|
|
|
pattern = r"^\d+x\d+$"
|
|
pattern = r"^\d+x\d+$"
|
|
if re.match(pattern, form_data.IMAGE_SIZE):
|
|
if re.match(pattern, form_data.IMAGE_SIZE):
|
|
- app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
|
|
|
|
|
|
+ request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
|
|
else:
|
|
else:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=400,
|
|
status_code=400,
|
|
@@ -249,7 +268,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
|
|
)
|
|
)
|
|
|
|
|
|
if form_data.IMAGE_STEPS >= 0:
|
|
if form_data.IMAGE_STEPS >= 0:
|
|
- app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
|
|
|
|
|
|
+ request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
|
|
else:
|
|
else:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=400,
|
|
status_code=400,
|
|
@@ -257,29 +276,31 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
|
|
)
|
|
)
|
|
|
|
|
|
return {
|
|
return {
|
|
- "MODEL": app.state.config.MODEL,
|
|
|
|
- "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
|
|
|
|
- "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
|
|
|
|
|
|
+ "MODEL": request.app.state.config.MODEL,
|
|
|
|
+ "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
|
|
|
|
+ "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
-@app.get("/models")
|
|
|
|
-def get_models(user=Depends(get_verified_user)):
|
|
|
|
|
|
+@router.get("/models")
|
|
|
|
+def get_models(request: Request, user=Depends(get_verified_user)):
|
|
try:
|
|
try:
|
|
- if app.state.config.ENGINE == "openai":
|
|
|
|
|
|
+ if request.app.state.config.ENGINE == "openai":
|
|
return [
|
|
return [
|
|
{"id": "dall-e-2", "name": "DALL·E 2"},
|
|
{"id": "dall-e-2", "name": "DALL·E 2"},
|
|
{"id": "dall-e-3", "name": "DALL·E 3"},
|
|
{"id": "dall-e-3", "name": "DALL·E 3"},
|
|
]
|
|
]
|
|
- elif app.state.config.ENGINE == "comfyui":
|
|
|
|
|
|
+ elif request.app.state.config.ENGINE == "comfyui":
|
|
# TODO - get models from comfyui
|
|
# TODO - get models from comfyui
|
|
- r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
|
|
|
|
|
|
+ r = requests.get(
|
|
|
|
+ url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
|
|
|
+ )
|
|
info = r.json()
|
|
info = r.json()
|
|
|
|
|
|
- workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
|
|
|
|
|
|
+ workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
|
|
model_node_id = None
|
|
model_node_id = None
|
|
|
|
|
|
- for node in app.state.config.COMFYUI_WORKFLOW_NODES:
|
|
|
|
|
|
+ for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
|
|
if node["type"] == "model":
|
|
if node["type"] == "model":
|
|
if node["node_ids"]:
|
|
if node["node_ids"]:
|
|
model_node_id = node["node_ids"][0]
|
|
model_node_id = node["node_ids"][0]
|
|
@@ -315,10 +336,11 @@ def get_models(user=Depends(get_verified_user)):
|
|
)
|
|
)
|
|
)
|
|
)
|
|
elif (
|
|
elif (
|
|
- app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
|
|
|
|
|
|
+ request.app.state.config.ENGINE == "automatic1111"
|
|
|
|
+ or request.app.state.config.ENGINE == ""
|
|
):
|
|
):
|
|
r = requests.get(
|
|
r = requests.get(
|
|
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
|
|
|
|
|
|
+ url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
|
|
headers={"authorization": get_automatic1111_api_auth()},
|
|
headers={"authorization": get_automatic1111_api_auth()},
|
|
)
|
|
)
|
|
models = r.json()
|
|
models = r.json()
|
|
@@ -329,7 +351,7 @@ def get_models(user=Depends(get_verified_user)):
|
|
)
|
|
)
|
|
)
|
|
)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
- app.state.config.ENABLED = False
|
|
|
|
|
|
+ request.app.state.config.ENABLED = False
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
|
|
|
|
|
|
|
|
|
@@ -401,18 +423,21 @@ def save_url_image(url):
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
|
|
-@app.post("/generations")
|
|
|
|
|
|
+@router.post("/generations")
|
|
async def image_generations(
|
|
async def image_generations(
|
|
|
|
+ request: Request,
|
|
form_data: GenerateImageForm,
|
|
form_data: GenerateImageForm,
|
|
user=Depends(get_verified_user),
|
|
user=Depends(get_verified_user),
|
|
):
|
|
):
|
|
- width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
|
|
|
|
|
|
+ width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
|
|
|
|
|
|
r = None
|
|
r = None
|
|
try:
|
|
try:
|
|
- if app.state.config.ENGINE == "openai":
|
|
|
|
|
|
+ if request.app.state.config.ENGINE == "openai":
|
|
headers = {}
|
|
headers = {}
|
|
- headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
|
|
|
|
|
|
+ headers["Authorization"] = (
|
|
|
|
+ f"Bearer {request.app.state.config.OPENAI_API_KEY}"
|
|
|
|
+ )
|
|
headers["Content-Type"] = "application/json"
|
|
headers["Content-Type"] = "application/json"
|
|
|
|
|
|
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
|
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
|
@@ -423,14 +448,16 @@ async def image_generations(
|
|
|
|
|
|
data = {
|
|
data = {
|
|
"model": (
|
|
"model": (
|
|
- app.state.config.MODEL
|
|
|
|
- if app.state.config.MODEL != ""
|
|
|
|
|
|
+ request.app.state.config.MODEL
|
|
|
|
+ if request.app.state.config.MODEL != ""
|
|
else "dall-e-2"
|
|
else "dall-e-2"
|
|
),
|
|
),
|
|
"prompt": form_data.prompt,
|
|
"prompt": form_data.prompt,
|
|
"n": form_data.n,
|
|
"n": form_data.n,
|
|
"size": (
|
|
"size": (
|
|
- form_data.size if form_data.size else app.state.config.IMAGE_SIZE
|
|
|
|
|
|
+ form_data.size
|
|
|
|
+ if form_data.size
|
|
|
|
+ else request.app.state.config.IMAGE_SIZE
|
|
),
|
|
),
|
|
"response_format": "b64_json",
|
|
"response_format": "b64_json",
|
|
}
|
|
}
|
|
@@ -438,7 +465,7 @@ async def image_generations(
|
|
# Use asyncio.to_thread for the requests.post call
|
|
# Use asyncio.to_thread for the requests.post call
|
|
r = await asyncio.to_thread(
|
|
r = await asyncio.to_thread(
|
|
requests.post,
|
|
requests.post,
|
|
- url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
|
|
|
|
|
|
+ url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations",
|
|
json=data,
|
|
json=data,
|
|
headers=headers,
|
|
headers=headers,
|
|
)
|
|
)
|
|
@@ -458,7 +485,7 @@ async def image_generations(
|
|
|
|
|
|
return images
|
|
return images
|
|
|
|
|
|
- elif app.state.config.ENGINE == "comfyui":
|
|
|
|
|
|
+ elif request.app.state.config.ENGINE == "comfyui":
|
|
data = {
|
|
data = {
|
|
"prompt": form_data.prompt,
|
|
"prompt": form_data.prompt,
|
|
"width": width,
|
|
"width": width,
|
|
@@ -466,8 +493,8 @@ async def image_generations(
|
|
"n": form_data.n,
|
|
"n": form_data.n,
|
|
}
|
|
}
|
|
|
|
|
|
- if app.state.config.IMAGE_STEPS is not None:
|
|
|
|
- data["steps"] = app.state.config.IMAGE_STEPS
|
|
|
|
|
|
+ if request.app.state.config.IMAGE_STEPS is not None:
|
|
|
|
+ data["steps"] = request.app.state.config.IMAGE_STEPS
|
|
|
|
|
|
if form_data.negative_prompt is not None:
|
|
if form_data.negative_prompt is not None:
|
|
data["negative_prompt"] = form_data.negative_prompt
|
|
data["negative_prompt"] = form_data.negative_prompt
|
|
@@ -476,18 +503,18 @@ async def image_generations(
|
|
**{
|
|
**{
|
|
"workflow": ComfyUIWorkflow(
|
|
"workflow": ComfyUIWorkflow(
|
|
**{
|
|
**{
|
|
- "workflow": app.state.config.COMFYUI_WORKFLOW,
|
|
|
|
- "nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
|
|
|
|
+ "workflow": request.app.state.config.COMFYUI_WORKFLOW,
|
|
|
|
+ "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
}
|
|
}
|
|
),
|
|
),
|
|
**data,
|
|
**data,
|
|
}
|
|
}
|
|
)
|
|
)
|
|
res = await comfyui_generate_image(
|
|
res = await comfyui_generate_image(
|
|
- app.state.config.MODEL,
|
|
|
|
|
|
+ request.app.state.config.MODEL,
|
|
form_data,
|
|
form_data,
|
|
user.id,
|
|
user.id,
|
|
- app.state.config.COMFYUI_BASE_URL,
|
|
|
|
|
|
+ request.app.state.config.COMFYUI_BASE_URL,
|
|
)
|
|
)
|
|
log.debug(f"res: {res}")
|
|
log.debug(f"res: {res}")
|
|
|
|
|
|
@@ -504,7 +531,8 @@ async def image_generations(
|
|
log.debug(f"images: {images}")
|
|
log.debug(f"images: {images}")
|
|
return images
|
|
return images
|
|
elif (
|
|
elif (
|
|
- app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
|
|
|
|
|
|
+ request.app.state.config.ENGINE == "automatic1111"
|
|
|
|
+ or request.app.state.config.ENGINE == ""
|
|
):
|
|
):
|
|
if form_data.model:
|
|
if form_data.model:
|
|
set_image_model(form_data.model)
|
|
set_image_model(form_data.model)
|
|
@@ -516,25 +544,25 @@ async def image_generations(
|
|
"height": height,
|
|
"height": height,
|
|
}
|
|
}
|
|
|
|
|
|
- if app.state.config.IMAGE_STEPS is not None:
|
|
|
|
- data["steps"] = app.state.config.IMAGE_STEPS
|
|
|
|
|
|
+ if request.app.state.config.IMAGE_STEPS is not None:
|
|
|
|
+ data["steps"] = request.app.state.config.IMAGE_STEPS
|
|
|
|
|
|
if form_data.negative_prompt is not None:
|
|
if form_data.negative_prompt is not None:
|
|
data["negative_prompt"] = form_data.negative_prompt
|
|
data["negative_prompt"] = form_data.negative_prompt
|
|
|
|
|
|
- if app.state.config.AUTOMATIC1111_CFG_SCALE:
|
|
|
|
- data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
|
|
|
|
|
|
+ if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
|
|
|
|
+ data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
|
|
|
|
|
|
- if app.state.config.AUTOMATIC1111_SAMPLER:
|
|
|
|
- data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
|
|
|
|
|
|
+ if request.app.state.config.AUTOMATIC1111_SAMPLER:
|
|
|
|
+ data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
|
|
|
|
|
|
- if app.state.config.AUTOMATIC1111_SCHEDULER:
|
|
|
|
- data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
|
|
|
|
|
|
+ if request.app.state.config.AUTOMATIC1111_SCHEDULER:
|
|
|
|
+ data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
|
|
|
|
|
|
# Use asyncio.to_thread for the requests.post call
|
|
# Use asyncio.to_thread for the requests.post call
|
|
r = await asyncio.to_thread(
|
|
r = await asyncio.to_thread(
|
|
requests.post,
|
|
requests.post,
|
|
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
|
|
|
|
|
|
+ url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
|
|
json=data,
|
|
json=data,
|
|
headers={"authorization": get_automatic1111_api_auth()},
|
|
headers={"authorization": get_automatic1111_api_auth()},
|
|
)
|
|
)
|