Browse Source

feat: sd frontend integration

Timothy J. Baek 1 year ago
parent
commit
cc50cc10e6

+ 2 - 0
example.env → .env.example

@@ -5,6 +5,8 @@ OLLAMA_API_BASE_URL='http://localhost:11434/api'
 OPENAI_API_BASE_URL=''
 OPENAI_API_KEY=''
 
+# AUTOMATIC1111_BASE_URL="http://localhost:7860"
+
 # DO NOT TRACK
 SCARF_NO_ANALYTICS=true
 DO_NOT_TRACK=true

+ 1 - 1
README.md

@@ -283,7 +283,7 @@ git clone https://github.com/open-webui/open-webui.git
 cd open-webui/
 
 # Copying required .env file
-cp -RPp example.env .env
+cp -RPp .env.example .env
 
 # Building Frontend Using Node
 npm i

+ 24 - 11
backend/apps/images/main.py

@@ -33,7 +33,7 @@ app.add_middleware(
 )
 
 app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
-app.state.ENABLED = False
+app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
 
 
 @app.get("/enabled", response_model=bool)
@@ -129,20 +129,33 @@ def generate_image(
     form_data: GenerateImageForm,
     user=Depends(get_current_user),
 ):
-    if form_data.model:
-        set_model_handler(form_data.model)
 
-    width, height = tuple(map(int, form_data.size.split("x")))
+    print(form_data)
 
-    r = requests.get(
-        url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
-        json={
+    try:
+        if form_data.model:
+            set_model_handler(form_data.model)
+
+        width, height = tuple(map(int, form_data.size.split("x")))
+
+        data = {
             "prompt": form_data.prompt,
-            "negative_prompt": form_data.negative_prompt,
             "batch_size": form_data.n,
             "width": width,
             "height": height,
-        },
-    )
+        }
+
+        if form_data.negative_prompt != None:
+            data["negative_prompt"] = form_data.negative_prompt
+
+        print(data)
 
-    return r.json()
+        r = requests.post(
+            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
+            json=data,
+        )
+
+        return r.json()
+    except Exception as e:
+        print(e)
+        raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))

+ 99 - 0
src/lib/apis/images/index.ts

@@ -1,5 +1,69 @@
 import { IMAGES_API_BASE_URL } from '$lib/constants';
 
+export const getImageGenerationEnabledStatus = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const toggleImageGenerationEnabledStatus = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getAUTOMATIC1111Url = async (token: string = '') => {
 	let error = null;
 
@@ -165,3 +229,38 @@ export const updateDefaultDiffusionModel = async (token: string = '', model: str
 
 	return res.model;
 };
+
+export const imageGenerations = async (token: string = '', prompt: string) => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/generations`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			prompt: prompt
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 1 - 0
src/lib/components/chat/Messages.svelte

@@ -11,6 +11,7 @@
 	import ResponseMessage from './Messages/ResponseMessage.svelte';
 	import Placeholder from './Messages/Placeholder.svelte';
 	import Spinner from '../common/Spinner.svelte';
+	import { imageGenerations } from '$lib/apis/images';
 
 	export let chatId = '';
 	export let sendPrompt: Function;

+ 84 - 15
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -16,6 +16,7 @@
 
 	import { synthesizeOpenAISpeech } from '$lib/apis/openai';
 	import { extractSentences } from '$lib/utils';
+	import { imageGenerations } from '$lib/apis/images';
 
 	export let modelfiles = [];
 	export let message;
@@ -43,6 +44,8 @@
 
 	let loadingSpeech = false;
 
+	let generatingImage = false;
+
 	$: tokens = marked.lexer(message.content);
 
 	const renderer = new marked.Renderer();
@@ -267,6 +270,21 @@
 		renderStyling();
 	};
 
+	const generateImage = async (message) => {
+		generatingImage = true;
+		const res = await imageGenerations(localStorage.token, message.content);
+		console.log(res);
+
+		if (res) {
+			message.files = res.images.map((image) => ({
+				type: 'image',
+				url: `data:image/png;base64,${image}`
+			}));
+		}
+
+		generatingImage = false;
+	};
+
 	onMount(async () => {
 		await tick();
 		renderStyling();
@@ -295,6 +313,18 @@
 			{#if message.content === ''}
 				<Skeleton />
 			{:else}
+				{#if message.files}
+					<div class="my-2.5 w-full flex overflow-x-auto gap-2 flex-wrap">
+						{#each message.files as file}
+							<div>
+								{#if file.type === 'image'}
+									<img src={file.url} alt="input" class=" max-h-96 rounded-lg" draggable="false" />
+								{/if}
+							</div>
+						{/each}
+					</div>
+				{/if}
+
 				<div
 					class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:m-0 prose-p:-mb-6 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-8 prose-li:-mb-4 whitespace-pre-line"
 				>
@@ -601,23 +631,62 @@
 													? 'visible'
 													: 'invisible group-hover:visible'} p-1 rounded dark:hover:text-white hover:text-black transition"
 												on:click={() => {
-													// generateImage
+													if (!generatingImage) {
+														generateImage(message);
+													}
 												}}
 											>
-												<svg
-													xmlns="http://www.w3.org/2000/svg"
-													fill="none"
-													viewBox="0 0 24 24"
-													stroke-width="1.5"
-													stroke="currentColor"
-													class="w-4 h-4"
-												>
-													<path
-														stroke-linecap="round"
-														stroke-linejoin="round"
-														d="m2.25 15.75 5.159-5.159a2.25 2.25 0 0 1 3.182 0l5.159 5.159m-1.5-1.5 1.409-1.409a2.25 2.25 0 0 1 3.182 0l2.909 2.909m-18 3.75h16.5a1.5 1.5 0 0 0 1.5-1.5V6a1.5 1.5 0 0 0-1.5-1.5H3.75A1.5 1.5 0 0 0 2.25 6v12a1.5 1.5 0 0 0 1.5 1.5Zm10.5-11.25h.008v.008h-.008V8.25Zm.375 0a.375.375 0 1 1-.75 0 .375.375 0 0 1 .75 0Z"
-													/>
-												</svg>
+												{#if generatingImage}
+													<svg
+														class=" w-4 h-4"
+														fill="currentColor"
+														viewBox="0 0 24 24"
+														xmlns="http://www.w3.org/2000/svg"
+														><style>
+															.spinner_S1WN {
+																animation: spinner_MGfb 0.8s linear infinite;
+																animation-delay: -0.8s;
+															}
+															.spinner_Km9P {
+																animation-delay: -0.65s;
+															}
+															.spinner_JApP {
+																animation-delay: -0.5s;
+															}
+															@keyframes spinner_MGfb {
+																93.75%,
+																100% {
+																	opacity: 0.2;
+																}
+															}
+														</style><circle class="spinner_S1WN" cx="4" cy="12" r="3" /><circle
+															class="spinner_S1WN spinner_Km9P"
+															cx="12"
+															cy="12"
+															r="3"
+														/><circle
+															class="spinner_S1WN spinner_JApP"
+															cx="20"
+															cy="12"
+															r="3"
+														/></svg
+													>
+												{:else}
+													<svg
+														xmlns="http://www.w3.org/2000/svg"
+														fill="none"
+														viewBox="0 0 24 24"
+														stroke-width="1.5"
+														stroke="currentColor"
+														class="w-4 h-4"
+													>
+														<path
+															stroke-linecap="round"
+															stroke-linejoin="round"
+															d="m2.25 15.75 5.159-5.159a2.25 2.25 0 0 1 3.182 0l5.159 5.159m-1.5-1.5 1.409-1.409a2.25 2.25 0 0 1 3.182 0l2.909 2.909m-18 3.75h16.5a1.5 1.5 0 0 0 1.5-1.5V6a1.5 1.5 0 0 0-1.5-1.5H3.75A1.5 1.5 0 0 0 2.25 6v12a1.5 1.5 0 0 0 1.5 1.5Zm10.5-11.25h.008v.008h-.008V8.25Zm.375 0a.375.375 0 1 1-.75 0 .375.375 0 0 1 .75 0Z"
+														/>
+													</svg>
+												{/if}
 											</button>
 										{/if}
 

+ 7 - 2
src/lib/components/chat/Settings/Images.svelte

@@ -2,14 +2,17 @@
 	import toast from 'svelte-french-toast';
 
 	import { createEventDispatcher, onMount } from 'svelte';
-	import { user } from '$lib/stores';
+	import { config, user } from '$lib/stores';
 	import {
 		getAUTOMATIC1111Url,
 		getDefaultDiffusionModel,
 		getDiffusionModels,
+		getImageGenerationEnabledStatus,
+		toggleImageGenerationEnabledStatus,
 		updateAUTOMATIC1111Url,
 		updateDefaultDiffusionModel
 	} from '$lib/apis/images';
+	import { getBackendConfig } from '$lib/apis';
 	const dispatch = createEventDispatcher();
 
 	export let saveSettings: Function;
@@ -42,11 +45,13 @@
 	};
 
 	const toggleImageGeneration = async () => {
-		enableImageGeneration = !enableImageGeneration;
+		enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token);
+		config.set(await getBackendConfig(localStorage.token));
 	};
 
 	onMount(async () => {
 		if ($user.role === 'admin') {
+			enableImageGeneration = await getImageGenerationEnabledStatus(localStorage.token);
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
 
 			if (AUTOMATIC1111_BASE_URL) {