Explorar o código

feat: image size param

Timothy J. Baek hai 1 ano
pai
achega
7ec4c07bf9

+ 30 - 2
backend/apps/images/main.py

@@ -1,4 +1,4 @@
-import os
+import re
 import requests
 from fastapi import (
     FastAPI,
@@ -34,6 +34,7 @@ app.add_middleware(
 
 app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
+app.state.IMAGE_SIZE = "512x512"
 
 
 @app.get("/enabled", response_model=bool)
@@ -74,6 +75,33 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use
     }
 
 
+class ImageSizeUpdateForm(BaseModel):
+    size: str
+
+
+@app.get("/size")
+async def get_image_size(user=Depends(get_admin_user)):
+    return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
+
+
+@app.post("/size/update")
+async def update_image_size(
+    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
+):
+    pattern = r"^\d+x\d+$"  # Regular expression pattern
+    if re.match(pattern, form_data.size):
+        app.state.IMAGE_SIZE = form_data.size
+        return {
+            "IMAGE_SIZE": app.state.IMAGE_SIZE,
+            "status": True,
+        }
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
+        )
+
+
 @app.get("/models")
 def get_models(user=Depends(get_current_user)):
     try:
@@ -140,7 +168,7 @@ def generate_image(
         if form_data.model:
             set_model_handler(form_data.model)
 
-        width, height = tuple(map(int, form_data.size.split("x")))
+        width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
 
         data = {
             "prompt": form_data.prompt,

+ 3 - 0
backend/constants.py

@@ -44,3 +44,6 @@ class ERROR_MESSAGES(str, Enum):
     MALICIOUS = "Unusual activities detected, please try again in a few minutes."
 
     PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
+    INCORRECT_FORMAT = (
+        lambda err="": f"Invalid format. Please use the correct format{err if err else ''}"
+    )

+ 67 - 0
src/lib/apis/images/index.ts

@@ -131,6 +131,73 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) =>
 	return res.AUTOMATIC1111_BASE_URL;
 };
 
+export const getImageSize = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/size`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.IMAGE_SIZE;
+};
+
+export const updateImageSize = async (token: string = '', size: string) => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/size/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			size: size
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.IMAGE_SIZE;
+};
+
 export const getDiffusionModels = async (token: string = '') => {
 	let error = null;
 

+ 25 - 4
src/lib/components/chat/Settings/Images.svelte

@@ -8,9 +8,11 @@
 		getDefaultDiffusionModel,
 		getDiffusionModels,
 		getImageGenerationEnabledStatus,
+		getImageSize,
 		toggleImageGenerationEnabledStatus,
 		updateAUTOMATIC1111Url,
-		updateDefaultDiffusionModel
+		updateDefaultDiffusionModel,
+		updateImageSize
 	} from '$lib/apis/images';
 	import { getBackendConfig } from '$lib/apis';
 	const dispatch = createEventDispatcher();
@@ -25,6 +27,8 @@
 	let selectedModel = '';
 	let models = [];
 
+	let imageSize = '';
+
 	const getModels = async () => {
 		models = await getDiffusionModels(localStorage.token).catch((error) => {
 			toast.error(error);
@@ -53,7 +57,6 @@
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
 		}
 	};
-
 	const toggleImageGeneration = async () => {
 		if (AUTOMATIC1111_BASE_URL) {
 			enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token).catch(
@@ -79,6 +82,7 @@
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
 
 			if (enableImageGeneration && AUTOMATIC1111_BASE_URL) {
+				imageSize = await getImageSize(localStorage.token);
 				getModels();
 			}
 		}
@@ -89,13 +93,17 @@
 	class="flex flex-col h-full justify-between space-y-3 text-sm"
 	on:submit|preventDefault={async () => {
 		loading = true;
-		const res = await updateDefaultDiffusionModel(localStorage.token, selectedModel);
+		await updateDefaultDiffusionModel(localStorage.token, selectedModel);
+		await updateImageSize(localStorage.token, imageSize).catch((error) => {
+			toast.error(error);
+			return null;
+		});
 
 		dispatch('save');
 		loading = false;
 	}}
 >
-	<div class=" space-y-3 pr-1.5 overflow-y-scroll max-h-80">
+	<div class=" space-y-3 pr-1.5 overflow-y-scroll max-h-[21rem]">
 		<div>
 			<div class=" mb-1 text-sm font-medium">Image Settings</div>
 
@@ -168,6 +176,19 @@
 		{#if enableImageGeneration}
 			<hr class=" dark:border-gray-700" />
 
+			<div>
+				<div class=" mb-2.5 text-sm font-medium">Set Image Size</div>
+				<div class="flex w-full">
+					<div class="flex-1 mr-2">
+						<input
+							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							placeholder="Enter Image Size (e.g. 512x512)"
+							bind:value={imageSize}
+						/>
+					</div>
+				</div>
+			</div>
+
 			<div>
 				<div class=" mb-2.5 text-sm font-medium">Set default model</div>
 				<div class="flex w-full">