|
@@ -17,6 +17,9 @@ from open_webui.apps.images.utils.comfyui import (
|
|
|
from open_webui.config import (
|
|
|
AUTOMATIC1111_API_AUTH,
|
|
|
AUTOMATIC1111_BASE_URL,
|
|
|
+ AUTOMATIC1111_CFG_SCALE,
|
|
|
+ AUTOMATIC1111_SAMPLER,
|
|
|
+ AUTOMATIC1111_SCHEDULER,
|
|
|
CACHE_DIR,
|
|
|
COMFYUI_BASE_URL,
|
|
|
COMFYUI_WORKFLOW,
|
|
@@ -65,6 +68,9 @@ app.state.config.MODEL = IMAGE_GENERATION_MODEL
|
|
|
|
|
|
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
|
|
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
|
|
|
+app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
|
|
|
+app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
|
|
|
+app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
|
|
|
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
|
|
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
|
|
|
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
|
|
@@ -85,6 +91,9 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
|
|
"automatic1111": {
|
|
|
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
|
|
|
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
|
|
|
+ "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
|
|
|
+ "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
|
|
|
+ "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
|
|
|
},
|
|
|
"comfyui": {
|
|
|
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
|
|
@@ -102,6 +111,9 @@ class OpenAIConfigForm(BaseModel):
|
|
|
class Automatic1111ConfigForm(BaseModel):
|
|
|
AUTOMATIC1111_BASE_URL: str
|
|
|
AUTOMATIC1111_API_AUTH: str
|
|
|
+ AUTOMATIC1111_CFG_SCALE: float
|
|
|
+ AUTOMATIC1111_SAMPLER: str
|
|
|
+ AUTOMATIC1111_SCHEDULER: str
|
|
|
|
|
|
|
|
|
class ComfyUIConfigForm(BaseModel):
|
|
@@ -132,6 +144,12 @@ async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
|
|
|
app.state.config.AUTOMATIC1111_API_AUTH = (
|
|
|
form_data.automatic1111.AUTOMATIC1111_API_AUTH
|
|
|
)
|
|
|
+ app.state.config.AUTOMATIC1111_CFG_SCALE = form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
|
|
|
+ app.state.config.AUTOMATIC1111_SAMPLER = form_data.automatic1111.AUTOMATIC1111_SAMPLER
|
|
|
+ app.state.config.AUTOMATIC1111_SCHEDULER = (
|
|
|
+ form_data.automatic1111.AUTOMATIC1111_SCHEDULER
|
|
|
+ )
|
|
|
+
|
|
|
|
|
|
app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
|
|
|
app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
|
|
@@ -147,6 +165,9 @@ async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
|
|
|
"automatic1111": {
|
|
|
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
|
|
|
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
|
|
|
+ "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
|
|
|
+ "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
|
|
|
+ "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
|
|
|
},
|
|
|
"comfyui": {
|
|
|
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
|
|
@@ -266,6 +287,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
|
|
|
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
|
|
|
)
|
|
|
|
|
|
+
|
|
|
return {
|
|
|
"MODEL": app.state.config.MODEL,
|
|
|
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
|
|
@@ -523,6 +545,15 @@ async def image_generations(
|
|
|
|
|
|
if form_data.negative_prompt is not None:
|
|
|
data["negative_prompt"] = form_data.negative_prompt
|
|
|
+
|
|
|
+ if app.state.config.AUTOMATIC1111_CFG_SCALE:
|
|
|
+ data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
|
|
|
+
|
|
|
+ if app.state.config.AUTOMATIC1111_SAMPLER:
|
|
|
+ data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
|
|
|
+
|
|
|
+ if app.state.config.AUTOMATIC1111_SCHEDULER:
|
|
|
+ data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
|
|
|
|
|
|
# Use asyncio.to_thread for the requests.post call
|
|
|
r = await asyncio.to_thread(
|