浏览代码

Allow configuration of steps, default to a1111 default

Justin Firth 1 年之前
父节点
当前提交
3fa79e59bf
共有 3 个文件被更改,包括 82 次插入1 次删除
  1. 29 0
      backend/apps/images/main.py
  2. 33 0
      src/lib/apis/images/index.ts
  3. 20 1
      src/lib/components/chat/Settings/Images.svelte

+ 29 - 0
backend/apps/images/main.py

@@ -100,6 +100,32 @@ async def update_image_size(
             status_code=400,
             status_code=400,
             detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
             detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
         )
         )
+    
+
+class ImageStepsUpdateForm(BaseModel):
+    steps: int
+
+
+@app.get("/steps")
+async def get_image_size(user=Depends(get_admin_user)):
+    return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
+
+
+@app.post("/steps/update")
+async def update_image_size(
+    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
+):
+    if form_data.steps >= 0:
+        app.state.IMAGE_STEPS = form_data.steps
+        return {
+            "IMAGE_STEPS": app.state.IMAGE_STEPS,
+            "status": True,
+        }
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
+        )
 
 
 
 
 @app.get("/models")
 @app.get("/models")
@@ -177,6 +203,9 @@ def generate_image(
             "height": height,
             "height": height,
         }
         }
 
 
+        if app.state.IMAGE_STEPS != None:
+            data["steps"] = app.state.IMAGE_STEPS
+
         if form_data.negative_prompt != None:
         if form_data.negative_prompt != None:
             data["negative_prompt"] = form_data.negative_prompt
             data["negative_prompt"] = form_data.negative_prompt
 
 

+ 33 - 0
src/lib/apis/images/index.ts

@@ -198,6 +198,39 @@ export const updateImageSize = async (token: string = '', size: string) => {
 	return res.IMAGE_SIZE;
 	return res.IMAGE_SIZE;
 };
 };
 
 
+export const updateImageSteps = async (token: string = '', steps: number) => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/steps/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({ steps })
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.IMAGE_STEPS;
+};
+
 export const getDiffusionModels = async (token: string = '') => {
 export const getDiffusionModels = async (token: string = '') => {
 	let error = null;
 	let error = null;
 
 

+ 20 - 1
src/lib/components/chat/Settings/Images.svelte

@@ -12,7 +12,8 @@
 		toggleImageGenerationEnabledStatus,
 		toggleImageGenerationEnabledStatus,
 		updateAUTOMATIC1111Url,
 		updateAUTOMATIC1111Url,
 		updateDefaultDiffusionModel,
 		updateDefaultDiffusionModel,
-		updateImageSize
+		updateImageSize,
+		updateImageSteps
 	} from '$lib/apis/images';
 	} from '$lib/apis/images';
 	import { getBackendConfig } from '$lib/apis';
 	import { getBackendConfig } from '$lib/apis';
 	const dispatch = createEventDispatcher();
 	const dispatch = createEventDispatcher();
@@ -28,6 +29,7 @@
 	let models = [];
 	let models = [];
 
 
 	let imageSize = '';
 	let imageSize = '';
+	let steps = 50;
 
 
 	const getModels = async () => {
 	const getModels = async () => {
 		models = await getDiffusionModels(localStorage.token).catch((error) => {
 		models = await getDiffusionModels(localStorage.token).catch((error) => {
@@ -98,6 +100,10 @@
 			toast.error(error);
 			toast.error(error);
 			return null;
 			return null;
 		});
 		});
+		await updateImageSteps(localStorage.token, steps).catch((error) => {
+			toast.error(error);
+			return null;
+		});
 
 
 		dispatch('save');
 		dispatch('save');
 		loading = false;
 		loading = false;
@@ -210,6 +216,19 @@
 					</div>
 					</div>
 				</div>
 				</div>
 			</div>
 			</div>
+
+			<div>
+				<div class=" mb-2.5 text-sm font-medium">Set Steps</div>
+				<div class="flex w-full">
+					<div class="flex-1 mr-2">
+						<input
+							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							placeholder="Enter Number of Steps (e.g. 50)"
+							bind:value={steps}
+						/>
+					</div>
+				</div>
+			</div>
 		{/if}
 		{/if}
 	</div>
 	</div>