소스 검색

feat: comfyui support

Timothy J. Baek 1 년 전
부모
커밋
862c96fcef
4개의 변경된 파일120개의 추가작업 그리고 32개의 파일을 삭제
  1. 34 9
      backend/apps/images/main.py
  2. 1 0
      backend/config.py
  3. 5 5
      src/lib/apis/images/index.ts
  4. 80 18
      src/lib/components/chat/Settings/Images.svelte

+ 34 - 9
backend/apps/images/main.py

@@ -26,7 +26,7 @@ import uuid
 import base64
 import json
 
-from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
+from config import CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL
 
 
 IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
@@ -49,6 +49,8 @@ app.state.MODEL = ""
 
 
 app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
+app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+
 
 app.state.IMAGE_SIZE = "512x512"
 app.state.IMAGE_STEPS = 50
@@ -71,32 +73,43 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user
     return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
 
 
-class UrlUpdateForm(BaseModel):
-    url: str
+class EngineUrlUpdateForm(BaseModel):
+    AUTOMATIC1111_BASE_URL: Optional[str] = None
+    COMFYUI_BASE_URL: Optional[str] = None
 
 
 @app.get("/url")
-async def get_automatic1111_url(user=Depends(get_admin_user)):
-    return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
+async def get_engine_url(user=Depends(get_admin_user)):
+    return {
+        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
+        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
+    }
 
 
 @app.post("/url/update")
-async def update_automatic1111_url(
-    form_data: UrlUpdateForm, user=Depends(get_admin_user)
+async def update_engine_url(
+    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
 ):
 
-    if form_data.url == "":
+    if form_data.AUTOMATIC1111_BASE_URL == None:
         app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
     else:
-        url = form_data.url.strip("/")
+        url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
         try:
             r = requests.head(url)
             app.state.AUTOMATIC1111_BASE_URL = url
         except Exception as e:
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 
+    if form_data.COMFYUI_BASE_URL == None:
+        app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+    else:
+        url = form_data.COMFYUI_BASE_URL.strip("/")
+        app.state.COMFYUI_BASE_URL = url
+
     return {
         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
+        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
         "status": True,
     }
 
@@ -186,6 +199,18 @@ def get_models(user=Depends(get_current_user)):
                 {"id": "dall-e-2", "name": "DALL·E 2"},
                 {"id": "dall-e-3", "name": "DALL·E 3"},
             ]
+        elif app.state.ENGINE == "comfyui":
+
+            r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
+            info = r.json()
+
+            return list(
+                map(
+                    lambda model: {"id": model, "name": model},
+                    info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
+                )
+            )
+
         else:
             r = requests.get(
                 url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"

+ 1 - 0
backend/config.py

@@ -376,3 +376,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models"
 ####################################
 
 AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
+COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")

+ 5 - 5
src/lib/apis/images/index.ts

@@ -139,7 +139,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
 	return res.OPENAI_API_KEY;
 };
 
-export const getAUTOMATIC1111Url = async (token: string = '') => {
+export const getImageGenerationEngineUrls = async (token: string = '') => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/url`, {
@@ -168,10 +168,10 @@ export const getAUTOMATIC1111Url = async (token: string = '') => {
 		throw error;
 	}
 
-	return res.AUTOMATIC1111_BASE_URL;
+	return res;
 };
 
-export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => {
+export const updateImageGenerationEngineUrls = async (token: string = '', urls: object = {}) => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, {
@@ -182,7 +182,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) =>
 			...(token && { authorization: `Bearer ${token}` })
 		},
 		body: JSON.stringify({
-			url: url
+			...urls
 		})
 	})
 		.then(async (res) => {
@@ -203,7 +203,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) =>
 		throw error;
 	}
 
-	return res.AUTOMATIC1111_BASE_URL;
+	return res;
 };
 
 export const getImageSize = async (token: string = '') => {

+ 80 - 18
src/lib/components/chat/Settings/Images.svelte

@@ -4,14 +4,14 @@
 	import { createEventDispatcher, onMount, getContext } from 'svelte';
 	import { config, user } from '$lib/stores';
 	import {
-		getAUTOMATIC1111Url,
 		getImageGenerationModels,
 		getDefaultImageGenerationModel,
 		updateDefaultImageGenerationModel,
 		getImageSize,
 		getImageGenerationConfig,
 		updateImageGenerationConfig,
-		updateAUTOMATIC1111Url,
+		getImageGenerationEngineUrls,
+		updateImageGenerationEngineUrls,
 		updateImageSize,
 		getImageSteps,
 		updateImageSteps,
@@ -31,6 +31,8 @@
 	let enableImageGeneration = false;
 
 	let AUTOMATIC1111_BASE_URL = '';
+	let COMFYUI_BASE_URL = '';
+
 	let OPENAI_API_KEY = '';
 
 	let selectedModel = '';
@@ -49,24 +51,47 @@
 		});
 	};
 
-	const updateAUTOMATIC1111UrlHandler = async () => {
-		const res = await updateAUTOMATIC1111Url(localStorage.token, AUTOMATIC1111_BASE_URL).catch(
-			(error) => {
+	const updateUrlHandler = async () => {
+		if (imageGenerationEngine === 'comfyui') {
+			const res = await updateImageGenerationEngineUrls(localStorage.token, {
+				COMFYUI_BASE_URL: COMFYUI_BASE_URL
+			}).catch((error) => {
 				toast.error(error);
+
+				console.log(error);
 				return null;
-			}
-		);
+			});
 
-		if (res) {
-			AUTOMATIC1111_BASE_URL = res;
+			if (res) {
+				COMFYUI_BASE_URL = res.COMFYUI_BASE_URL;
 
-			await getModels();
+				await getModels();
 
-			if (models) {
-				toast.success($i18n.t('Server connection verified'));
+				if (models) {
+					toast.success($i18n.t('Server connection verified'));
+				}
+			} else {
+				({ COMFYUI_BASE_URL } = await getImageGenerationEngineUrls(localStorage.token));
 			}
 		} else {
-			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
+			const res = await updateImageGenerationEngineUrls(localStorage.token, {
+				AUTOMATIC1111_BASE_URL: AUTOMATIC1111_BASE_URL
+			}).catch((error) => {
+				toast.error(error);
+				return null;
+			});
+
+			if (res) {
+				AUTOMATIC1111_BASE_URL = res.AUTOMATIC1111_BASE_URL;
+
+				await getModels();
+
+				if (models) {
+					toast.success($i18n.t('Server connection verified'));
+				}
+			} else {
+				({ AUTOMATIC1111_BASE_URL } = await getImageGenerationEngineUrls(localStorage.token));
+			}
 		}
 	};
 	const updateImageGeneration = async () => {
@@ -101,7 +126,11 @@
 				imageGenerationEngine = res.engine;
 				enableImageGeneration = res.enabled;
 			}
-			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
+			const URLS = await getImageGenerationEngineUrls(localStorage.token);
+
+			AUTOMATIC1111_BASE_URL = URLS.AUTOMATIC1111_BASE_URL;
+			COMFYUI_BASE_URL = URLS.COMFYUI_BASE_URL;
+
 			OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
 
 			imageSize = await getImageSize(localStorage.token);
@@ -154,6 +183,7 @@
 						}}
 					>
 						<option value="">{$i18n.t('Default (Automatic1111)')}</option>
+						<option value="comfyui">{$i18n.t('ComfyUI')}</option>
 						<option value="openai">{$i18n.t('Open AI (Dall-E)')}</option>
 					</select>
 				</div>
@@ -171,6 +201,9 @@
 							if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') {
 								toast.error($i18n.t('AUTOMATIC1111 Base URL is required.'));
 								enableImageGeneration = false;
+							} else if (imageGenerationEngine === 'comfyui' && COMFYUI_BASE_URL === '') {
+								toast.error($i18n.t('ComfyUI Base URL is required.'));
+								enableImageGeneration = false;
 							} else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') {
 								toast.error($i18n.t('OpenAI API Key is required.'));
 								enableImageGeneration = false;
@@ -204,12 +237,10 @@
 					/>
 				</div>
 				<button
-					class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded-lg transition"
+					class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
 					type="button"
 					on:click={() => {
-						// updateOllamaAPIUrlHandler();
-
-						updateAUTOMATIC1111UrlHandler();
+						updateUrlHandler();
 					}}
 				>
 					<svg
@@ -237,6 +268,37 @@
 					{$i18n.t('(e.g. `sh webui.sh --api`)')}
 				</a>
 			</div>
+		{:else if imageGenerationEngine === 'comfyui'}
+			<div class=" mb-2.5 text-sm font-medium">{$i18n.t('ComfyUI Base URL')}</div>
+			<div class="flex w-full">
+				<div class="flex-1 mr-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder={$i18n.t('Enter URL (e.g. http://127.0.0.1:7860/)')}
+						bind:value={COMFYUI_BASE_URL}
+					/>
+				</div>
+				<button
+					class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+					type="button"
+					on:click={() => {
+						updateUrlHandler();
+					}}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						viewBox="0 0 20 20"
+						fill="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							fill-rule="evenodd"
+							d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
+							clip-rule="evenodd"
+						/>
+					</svg>
+				</button>
+			</div>
 		{:else if imageGenerationEngine === 'openai'}
 			<div class=" mb-2.5 text-sm font-medium">{$i18n.t('OpenAI API Key')}</div>
 			<div class="flex w-full">