Browse Source

feat: custom comfyui workflow

Co-Authored-By: John Karabudak <hello@johnthenerd.com>
Timothy J. Baek 8 months ago
parent
commit
063e006446
2 changed files with 75 additions and 14 deletions
  1. 36 5
      backend/apps/images/main.py
  2. 39 9
      src/lib/components/admin/Settings/Images.svelte

+ 36 - 5
backend/apps/images/main.py

@@ -268,12 +268,43 @@ def get_models(user=Depends(get_verified_user)):
             r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
             info = r.json()
 
-            return list(
-                map(
-                    lambda model: {"id": model, "name": model},
-                    info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
+            workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
+            model_node_id = None
+
+            for node in app.state.config.COMFYUI_WORKFLOW_NODES:
+                if node["type"] == "model":
+                    model_node_id = node["node_ids"][0]
+                    break
+
+            if model_node_id:
+                model_list_key = None
+
+                print(workflow[model_node_id]["class_type"])
+                for key in info[workflow[model_node_id]["class_type"]]["input"][
+                    "required"
+                ]:
+                    if "_name" in key:
+                        model_list_key = key
+                        break
+
+                if model_list_key:
+                    return list(
+                        map(
+                            lambda model: {"id": model, "name": model},
+                            info[workflow[model_node_id]["class_type"]]["input"][
+                                "required"
+                            ][model_list_key][0],
+                        )
+                    )
+            else:
+                return list(
+                    map(
+                        lambda model: {"id": model, "name": model},
+                        info["CheckpointLoaderSimple"]["input"]["required"][
+                            "ckpt_name"
+                        ][0],
+                    )
                 )
-            )
         elif (
             app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
         ):

+ 39 - 9
src/lib/components/admin/Settings/Images.svelte

@@ -30,27 +30,27 @@
 		{
 			type: 'prompt',
 			key: 'text',
-			node_ids: []
+			node_ids: ''
 		},
 		{
 			type: 'model',
 			key: 'ckpt_name',
-			node_ids: []
+			node_ids: ''
 		},
 		{
 			type: 'width',
 			key: 'width',
-			node_ids: []
+			node_ids: ''
 		},
 		{
 			type: 'height',
 			key: 'height',
-			node_ids: []
+			node_ids: ''
 		},
 		{
 			type: 'steps',
 			key: 'steps',
-			node_ids: []
+			node_ids: ''
 		}
 	];
 
@@ -99,6 +99,16 @@
 			}
 		}
 
+		if (config?.comfyui?.COMFYUI_WORKFLOW) {
+			config.comfyui.COMFYUI_WORKFLOW_NODES = workflowNodes.map((node) => {
+				return {
+					type: node.type,
+					key: node.key,
+					node_ids: node.node_ids.split(',').map((id) => id.trim())
+				};
+			});
+		}
+
 		await updateConfig(localStorage.token, config).catch((error) => {
 			toast.error(error);
 			loading = false;
@@ -111,6 +121,7 @@
 			return null;
 		});
 
+		getModels();
 		dispatch('save');
 		loading = false;
 	};
@@ -130,6 +141,24 @@
 				getModels();
 			}
 
+			if (config.comfyui.COMFYUI_WORKFLOW) {
+				config.comfyui.COMFYUI_WORKFLOW = JSON.stringify(
+					JSON.parse(config.comfyui.COMFYUI_WORKFLOW),
+					null,
+					2
+				);
+			}
+
+			if ((config?.comfyui?.COMFYUI_WORKFLOW_NODES ?? []).length >= 5) {
+				workflowNodes = config.comfyui.COMFYUI_WORKFLOW_NODES.map((node) => {
+					return {
+						type: node.type,
+						key: node.key,
+						node_ids: node.node_ids.join(',')
+					};
+				});
+			}
+
 			const imageConfigRes = await getImageGenerationConfig(localStorage.token).catch((error) => {
 				toast.error(error);
 				return null;
@@ -321,7 +350,8 @@
 							<textarea
 								class="w-full rounded-lg mb-1 py-2 px-4 text-xs bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none disabled:text-gray-600 resize-none"
 								rows="10"
-								value={JSON.stringify(JSON.parse(config.comfyui.COMFYUI_WORKFLOW), null, 2)}
+								bind:value={config.comfyui.COMFYUI_WORKFLOW}
+								required
 							/>
 						{/if}
 
@@ -338,8 +368,6 @@
 
 										reader.onload = (e) => {
 											config.comfyui.COMFYUI_WORKFLOW = e.target.result;
-											updateConfigHandler();
-
 											e.target.value = null;
 										};
 
@@ -384,19 +412,21 @@
 													class="py-1 px-3 w-24 text-xs text-center bg-transparent outline-none border-r dark:border-gray-850"
 													placeholder="Key"
 													bind:value={node.key}
+													required
 												/>
 											</Tooltip>
 										</div>
 
 										<div class="w-full">
 											<Tooltip
-												content="Comma separated Node Ids (e.g. 1,2,3)"
+												content="Comma separated Node Ids (e.g. 1 or 1,2)"
 												placement="top-start"
 											>
 												<input
 													class="w-full py-1 px-4 rounded-r-lg text-xs bg-transparent outline-none"
 													placeholder="Node Ids"
 													bind:value={node.node_ids}
+													required
 												/>
 											</Tooltip>
 										</div>