浏览代码

refac: image generation

Timothy J. Baek 1 年之前
父节点
当前提交
dd3a4b3889
共有 2 个文件被更改,包括 127 次插入83 次删除
  1. 17 9
      src/lib/apis/images/index.ts
  2. 110 74
      src/lib/components/chat/Settings/Images.svelte

+ 17 - 9
src/lib/apis/images/index.ts

@@ -1,9 +1,9 @@
 import { IMAGES_API_BASE_URL } from '$lib/constants';
 
-export const getImageGenerationEnabledStatus = async (token: string = '') => {
+export const getImageGenerationConfig = async (token: string = '') => {
 	let error = null;
 
-	const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, {
+	const res = await fetch(`${IMAGES_API_BASE_URL}/config`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -32,16 +32,24 @@ export const getImageGenerationEnabledStatus = async (token: string = '') => {
 	return res;
 };
 
-export const toggleImageGenerationEnabledStatus = async (token: string = '') => {
+export const updateImageGenerationConfig = async (
+	token: string = '',
+	engine: string,
+	enabled: boolean
+) => {
 	let error = null;
 
-	const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, {
-		method: 'GET',
+	const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, {
+		method: 'POST',
 		headers: {
 			Accept: 'application/json',
 			'Content-Type': 'application/json',
 			...(token && { authorization: `Bearer ${token}` })
-		}
+		},
+		body: JSON.stringify({
+			engine,
+			enabled
+		})
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
@@ -263,7 +271,7 @@ export const updateImageSteps = async (token: string = '', steps: number) => {
 	return res.IMAGE_STEPS;
 };
 
-export const getDiffusionModels = async (token: string = '') => {
+export const getImageGenerationModels = async (token: string = '') => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/models`, {
@@ -295,7 +303,7 @@ export const getDiffusionModels = async (token: string = '') => {
 	return res;
 };
 
-export const getDefaultDiffusionModel = async (token: string = '') => {
+export const getDefaultImageGenerationModel = async (token: string = '') => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, {
@@ -327,7 +335,7 @@ export const getDefaultDiffusionModel = async (token: string = '') => {
 	return res.model;
 };
 
-export const updateDefaultDiffusionModel = async (token: string = '', model: string) => {
+export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, {

+ 110 - 74
src/lib/components/chat/Settings/Images.svelte

@@ -5,13 +5,13 @@
 	import { config, user } from '$lib/stores';
 	import {
 		getAUTOMATIC1111Url,
-		getDefaultDiffusionModel,
-		getDiffusionModels,
-		getImageGenerationEnabledStatus,
+		getImageGenerationModels,
+		getDefaultImageGenerationModel,
+		updateDefaultImageGenerationModel,
 		getImageSize,
-		toggleImageGenerationEnabledStatus,
+		getImageGenerationConfig,
+		updateImageGenerationConfig,
 		updateAUTOMATIC1111Url,
-		updateDefaultDiffusionModel,
 		updateImageSize,
 		getImageSteps,
 		updateImageSteps
@@ -23,7 +23,9 @@
 
 	let loading = false;
 
+	let imageGenerationEngine = '';
 	let enableImageGeneration = false;
+
 	let AUTOMATIC1111_BASE_URL = '';
 
 	let selectedModel = '';
@@ -33,11 +35,11 @@
 	let steps = 50;
 
 	const getModels = async () => {
-		models = await getDiffusionModels(localStorage.token).catch((error) => {
+		models = await getImageGenerationModels(localStorage.token).catch((error) => {
 			toast.error(error);
 			return null;
 		});
-		selectedModel = await getDefaultDiffusionModel(localStorage.token).catch((error) => {
+		selectedModel = await getDefaultImageGenerationModel(localStorage.token).catch((error) => {
 			return '';
 		});
 	};
@@ -62,33 +64,44 @@
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
 		}
 	};
-	const toggleImageGeneration = async () => {
-		if (AUTOMATIC1111_BASE_URL) {
-			enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token).catch(
-				(error) => {
-					toast.error(error);
-					return false;
-				}
-			);
+	const updateImageGeneration = async () => {
+		const res = await updateImageGenerationConfig(
+			localStorage.token,
+			imageGenerationEngine,
+			enableImageGeneration
+		).catch((error) => {
+			toast.error(error);
+			return null;
+		});
 
-			if (enableImageGeneration) {
-				config.set(await getBackendConfig(localStorage.token));
-				getModels();
-			}
-		} else {
-			enableImageGeneration = false;
-			toast.error('AUTOMATIC1111_BASE_URL not provided');
+		if (res) {
+			imageGenerationEngine = res.engine;
+			enableImageGeneration = res.enabled;
+		}
+
+		if (enableImageGeneration) {
+			config.set(await getBackendConfig(localStorage.token));
+			getModels();
 		}
 	};
 
 	onMount(async () => {
 		if ($user.role === 'admin') {
-			enableImageGeneration = await getImageGenerationEnabledStatus(localStorage.token);
+			const res = await getImageGenerationConfig(localStorage.token).catch((error) => {
+				toast.error(error);
+				return null;
+			});
+
+			if (res) {
+				imageGenerationEngine = res.engine;
+				enableImageGeneration = res.enabled;
+			}
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
 
-			if (enableImageGeneration && AUTOMATIC1111_BASE_URL) {
-				imageSize = await getImageSize(localStorage.token);
-				steps = await getImageSteps(localStorage.token);
+			imageSize = await getImageSize(localStorage.token);
+			steps = await getImageSteps(localStorage.token);
+
+			if (enableImageGeneration) {
 				getModels();
 			}
 		}
@@ -99,7 +112,7 @@
 	class="flex flex-col h-full justify-between space-y-3 text-sm"
 	on:submit|preventDefault={async () => {
 		loading = true;
-		await updateDefaultDiffusionModel(localStorage.token, selectedModel);
+		await updateDefaultImageGenerationModel(localStorage.token, selectedModel);
 		await updateImageSize(localStorage.token, imageSize).catch((error) => {
 			toast.error(error);
 			return null;
@@ -117,6 +130,23 @@
 		<div>
 			<div class=" mb-1 text-sm font-medium">Image Settings</div>
 
+			<div class=" py-0.5 flex w-full justify-between">
+				<div class=" self-center text-xs font-medium">Image Generation Engine</div>
+				<div class="flex items-center relative">
+					<select
+						class="w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
+						bind:value={imageGenerationEngine}
+						placeholder="Select a mode"
+						on:change={async () => {
+							await updateImageGeneration();
+						}}
+					>
+						<option value="">Default (Automatic1111)</option>
+						<option value="openai">Open AI (Dall-E)</option>
+					</select>
+				</div>
+			</div>
+
 			<div>
 				<div class=" py-0.5 flex w-full justify-between">
 					<div class=" self-center text-xs font-medium">Image Generation (Experimental)</div>
@@ -124,7 +154,12 @@
 					<button
 						class="p-1 px-3 text-xs flex rounded transition"
 						on:click={() => {
-							toggleImageGeneration();
+							if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') {
+								toast.error('AUTOMATIC1111 Base URL is required.');
+							} else {
+								enableImageGeneration = !enableImageGeneration;
+								updateImageGeneration();
+							}
 						}}
 						type="button"
 					>
@@ -137,51 +172,54 @@
 				</div>
 			</div>
 		</div>
-		<hr class=" dark:border-gray-700" />
-
-		<div class=" mb-2.5 text-sm font-medium">AUTOMATIC1111 Base URL</div>
-		<div class="flex w-full">
-			<div class="flex-1 mr-2">
-				<input
-					class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-					placeholder="Enter URL (e.g. http://127.0.0.1:7860/)"
-					bind:value={AUTOMATIC1111_BASE_URL}
-				/>
-			</div>
-			<button
-				class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded transition"
-				type="button"
-				on:click={() => {
-					// updateOllamaAPIUrlHandler();
-
-					updateAUTOMATIC1111UrlHandler();
-				}}
-			>
-				<svg
-					xmlns="http://www.w3.org/2000/svg"
-					viewBox="0 0 20 20"
-					fill="currentColor"
-					class="w-4 h-4"
-				>
-					<path
-						fill-rule="evenodd"
-						d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
-						clip-rule="evenodd"
+
+		{#if imageGenerationEngine === ''}
+			<hr class=" dark:border-gray-700" />
+
+			<div class=" mb-2.5 text-sm font-medium">AUTOMATIC1111 Base URL</div>
+			<div class="flex w-full">
+				<div class="flex-1 mr-2">
+					<input
+						class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+						placeholder="Enter URL (e.g. http://127.0.0.1:7860/)"
+						bind:value={AUTOMATIC1111_BASE_URL}
 					/>
-				</svg>
-			</button>
-		</div>
+				</div>
+				<button
+					class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded transition"
+					type="button"
+					on:click={() => {
+						// updateOllamaAPIUrlHandler();
 
-		<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
-			Include `--api` flag when running stable-diffusion-webui
-			<a
-				class=" text-gray-300 font-medium"
-				href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/3734"
-				target="_blank"
-			>
-				(e.g. `sh webui.sh --api`)
-			</a>
-		</div>
+						updateAUTOMATIC1111UrlHandler();
+					}}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						viewBox="0 0 20 20"
+						fill="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							fill-rule="evenodd"
+							d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
+							clip-rule="evenodd"
+						/>
+					</svg>
+				</button>
+			</div>
+
+			<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
+				Include `--api` flag when running stable-diffusion-webui
+				<a
+					class=" text-gray-300 font-medium"
+					href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/3734"
+					target="_blank"
+				>
+					(e.g. `sh webui.sh --api`)
+				</a>
+			</div>
+		{/if}
 
 		{#if enableImageGeneration}
 			<hr class=" dark:border-gray-700" />
@@ -199,9 +237,7 @@
 								<option value="" disabled selected>Select a model</option>
 							{/if}
 							{#each models ?? [] as model}
-								<option value={model.title} class="bg-gray-100 dark:bg-gray-700"
-									>{model.model_name}</option
-								>
+								<option value={model.id} class="bg-gray-100 dark:bg-gray-700">{model.name}</option>
 							{/each}
 						</select>
 					</div>