소스 검색

Merge pull request #888 from jmfirth/feature/img-gen-steps-setting

Add steps setting for image generation
Timothy Jaeryang Baek 1 년 전
부모
커밋
92ec0d90e1
4개의 변경된 파일119개의 추가작업 그리고 2개의 파일을 삭제
  1. 1 0
      CHANGELOG.md
  2. 30 0
      backend/apps/images/main.py
  3. 65 0
      src/lib/apis/images/index.ts
  4. 23 2
      src/lib/components/chat/Settings/Images.svelte

+ 1 - 0
CHANGELOG.md

@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - **Built-in LiteLLM Proxy**: Open WebUI now ships with LiteLLM Proxy.
 - **Image Generation Enhancements**: Advanced Settings + Image Preview Feature.
+  - Allows setting number of steps for image generation; defaults to a1111 default value.
 
 ### Fixed
 

+ 30 - 0
backend/apps/images/main.py

@@ -35,6 +35,7 @@ app.add_middleware(
 app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
 app.state.IMAGE_SIZE = "512x512"
+app.state.IMAGE_STEPS = 50
 
 
 @app.get("/enabled", response_model=bool)
@@ -102,6 +103,32 @@ async def update_image_size(
         )
 
 
+class ImageStepsUpdateForm(BaseModel):
+    steps: int
+
+
+@app.get("/steps")
+async def get_image_size(user=Depends(get_admin_user)):
+    return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
+
+
+@app.post("/steps/update")
+async def update_image_size(
+    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
+):
+    if form_data.steps >= 0:
+        app.state.IMAGE_STEPS = form_data.steps
+        return {
+            "IMAGE_STEPS": app.state.IMAGE_STEPS,
+            "status": True,
+        }
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
+        )
+
+
 @app.get("/models")
 def get_models(user=Depends(get_current_user)):
     try:
@@ -179,6 +206,9 @@ def generate_image(
             "height": height,
         }
 
+        if app.state.IMAGE_STEPS != None:
+            data["steps"] = app.state.IMAGE_STEPS
+
         if form_data.negative_prompt != None:
             data["negative_prompt"] = form_data.negative_prompt
 

+ 65 - 0
src/lib/apis/images/index.ts

@@ -198,6 +198,71 @@ export const updateImageSize = async (token: string = '', size: string) => {
 	return res.IMAGE_SIZE;
 };
 
+export const getImageSteps = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/steps`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.IMAGE_STEPS;
+};
+
+export const updateImageSteps = async (token: string = '', steps: number) => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/steps/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({ steps })
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.IMAGE_STEPS;
+};
+
 export const getDiffusionModels = async (token: string = '') => {
 	let error = null;
 

+ 23 - 2
src/lib/components/chat/Settings/Images.svelte

@@ -12,7 +12,9 @@
 		toggleImageGenerationEnabledStatus,
 		updateAUTOMATIC1111Url,
 		updateDefaultDiffusionModel,
-		updateImageSize
+		updateImageSize,
+		getImageSteps,
+		updateImageSteps
 	} from '$lib/apis/images';
 	import { getBackendConfig } from '$lib/apis';
 	const dispatch = createEventDispatcher();
@@ -21,13 +23,14 @@
 
 	let loading = false;
 
-	let enableImageGeneration = true;
+	let enableImageGeneration = false;
 	let AUTOMATIC1111_BASE_URL = '';
 
 	let selectedModel = '';
 	let models = [];
 
 	let imageSize = '';
+	let steps = 50;
 
 	const getModels = async () => {
 		models = await getDiffusionModels(localStorage.token).catch((error) => {
@@ -85,6 +88,7 @@
 
 			if (enableImageGeneration && AUTOMATIC1111_BASE_URL) {
 				imageSize = await getImageSize(localStorage.token);
+				steps = await getImageSteps(localStorage.token);
 				getModels();
 			}
 		}
@@ -100,6 +104,10 @@
 			toast.error(error);
 			return null;
 		});
+		await updateImageSteps(localStorage.token, steps).catch((error) => {
+			toast.error(error);
+			return null;
+		});
 
 		dispatch('save');
 		loading = false;
@@ -212,6 +220,19 @@
 					</div>
 				</div>
 			</div>
+
+			<div>
+				<div class=" mb-2.5 text-sm font-medium">Set Steps</div>
+				<div class="flex w-full">
+					<div class="flex-1 mr-2">
+						<input
+							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							placeholder="Enter Number of Steps (e.g. 50)"
+							bind:value={steps}
+						/>
+					</div>
+				</div>
+			</div>
 		{/if}
 	</div>