Timothy J. Baek il y a 10 mois
Parent
commit
627705a347

+ 57 - 11
backend/apps/webui/routers/functions.py

@@ -127,8 +127,8 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
     function = Functions.get_function_by_id(id)
     if function:
         try:
-            valves = Functions.get_function_valves_by_id(id)
-            return valves
+            function_valves = Functions.get_function_valves_by_id(id)
+            return function_valves.valves
         except Exception as e:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -142,24 +142,70 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
 
 
 ############################
-# UpdateToolValves
+# GetFunctionValvesSpec
+############################
+
+
+@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
+async def get_function_valves_spec_by_id(
+    request: Request, id: str, user=Depends(get_admin_user)
+):
+    function = Functions.get_function_by_id(id)
+    if function:
+        if id in request.app.state.FUNCTIONS:
+            function_module = request.app.state.FUNCTIONS[id]
+        else:
+            function_module, function_type = load_function_module_by_id(id)
+            request.app.state.FUNCTIONS[id] = function_module
+
+        if hasattr(function_module, "Valves"):
+            Valves = function_module.Valves
+            return Valves.schema()
+        return None
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# UpdateFunctionValves
 ############################
 
 
 @router.post("/id/{id}/valves/update", response_model=Optional[dict])
-async def update_toolkit_valves_by_id(
-    id: str, form_data: dict, user=Depends(get_admin_user)
+async def update_function_valves_by_id(
+    request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
 ):
     function = Functions.get_function_by_id(id)
     if function:
-        try:
-            valves = Functions.update_function_valves_by_id(id, form_data)
-            return valves
-        except Exception as e:
+
+        if id in request.app.state.FUNCTIONS:
+            function_module = request.app.state.FUNCTIONS[id]
+        else:
+            function_module, function_type = load_function_module_by_id(id)
+            request.app.state.FUNCTIONS[id] = function_module
+
+        if hasattr(function_module, "Valves"):
+            Valves = function_module.Valves
+
+            try:
+                valves = Valves(**form_data)
+                Functions.update_function_valves_by_id(id, valves.model_dump())
+                return valves.model_dump()
+            except Exception as e:
+                print(e)
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )
+        else:
             raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.NOT_FOUND,
             )
+
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,

+ 54 - 9
backend/apps/webui/routers/tools.py

@@ -133,8 +133,8 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
     toolkit = Tools.get_tool_by_id(id)
     if toolkit:
         try:
-            valves = Tools.get_tool_valves_by_id(id)
-            return valves
+            tool_valves = Tools.get_tool_valves_by_id(id)
+            return tool_valves.valves
         except Exception as e:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -147,6 +147,34 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
         )
 
 
+############################
+# GetToolValvesSpec
+############################
+
+
+@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
+async def get_toolkit_valves_spec_by_id(
+    request: Request, id: str, user=Depends(get_admin_user)
+):
+    toolkit = Tools.get_tool_by_id(id)
+    if toolkit:
+        if id in request.app.state.TOOLS:
+            toolkit_module = request.app.state.TOOLS[id]
+        else:
+            toolkit_module = load_toolkit_module_by_id(id)
+            request.app.state.TOOLS[id] = toolkit_module
+
+        if hasattr(toolkit_module, "UserValves"):
+            UserValves = toolkit_module.UserValves
+            return UserValves.schema()
+        return None
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 # UpdateToolValves
 ############################
@@ -154,18 +182,35 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
 
 @router.post("/id/{id}/valves/update", response_model=Optional[dict])
 async def update_toolkit_valves_by_id(
-    id: str, form_data: dict, user=Depends(get_admin_user)
+    request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
 ):
     toolkit = Tools.get_tool_by_id(id)
     if toolkit:
-        try:
-            valves = Tools.update_tool_valves_by_id(id, form_data)
-            return valves
-        except Exception as e:
+        if id in request.app.state.TOOLS:
+            toolkit_module = request.app.state.TOOLS[id]
+        else:
+            toolkit_module = load_toolkit_module_by_id(id)
+            request.app.state.TOOLS[id] = toolkit_module
+
+        if hasattr(toolkit_module, "Valves"):
+            Valves = toolkit_module.Valves
+
+            try:
+                valves = Valves(**form_data)
+                Tools.update_tool_valves_by_id(id, valves.model_dump())
+                return valves.model_dump()
+            except Exception as e:
+                print(e)
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )
+        else:
             raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.NOT_FOUND,
             )
+
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,

+ 32 - 0
src/lib/apis/functions/index.ts

@@ -256,6 +256,38 @@ export const getFunctionValvesById = async (token: string, id: string) => {
 	return res;
 };
 
+export const getFunctionValvesSpecById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/spec`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const updateFunctionValvesById = async (token: string, id: string, valves: object) => {
 	let error = null;
 

+ 99 - 0
src/lib/apis/tools/index.ts

@@ -192,6 +192,105 @@ export const deleteToolById = async (token: string, id: string) => {
 	return res;
 };
 
+export const getToolValvesById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getToolValvesSpecById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/spec`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateToolValvesById = async (token: string, id: string, valves: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...valves
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getUserValvesById = async (token: string, id: string) => {
 	let error = null;
 

+ 10 - 0
src/lib/components/workspace/Functions.svelte

@@ -24,6 +24,7 @@
 	import FunctionMenu from './Functions/FunctionMenu.svelte';
 	import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte';
 	import Switch from '../common/Switch.svelte';
+	import ValvesModal from './ValvesModal.svelte';
 
 	const i18n = getContext('i18n');
 
@@ -33,6 +34,9 @@
 	let showConfirm = false;
 	let query = '';
 
+	let showValvesModal = false;
+	let selectedFunction = null;
+
 	const shareHandler = async (tool) => {
 		console.log(tool);
 	};
@@ -175,6 +179,10 @@
 					<button
 						class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
 						type="button"
+						on:click={() => {
+							selectedFunction = func;
+							showValvesModal = true;
+						}}
 					>
 						<svg
 							xmlns="http://www.w3.org/2000/svg"
@@ -352,6 +360,8 @@
 	</a>
 </div>
 
+<ValvesModal bind:show={showValvesModal} type="function" id={selectedFunction?.id ?? null} />
+
 <ConfirmDialog
 	bind:show={showConfirm}
 	on:confirm={() => {

+ 10 - 0
src/lib/components/workspace/Tools.svelte

@@ -20,6 +20,7 @@
 	import ConfirmDialog from '../common/ConfirmDialog.svelte';
 	import ToolMenu from './Tools/ToolMenu.svelte';
 	import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte';
+	import ValvesModal from './ValvesModal.svelte';
 
 	const i18n = getContext('i18n');
 
@@ -29,6 +30,9 @@
 	let showConfirm = false;
 	let query = '';
 
+	let showValvesModal = false;
+	let selectedTool = null;
+
 	const shareHandler = async (tool) => {
 		console.log(tool);
 	};
@@ -169,6 +173,10 @@
 					<button
 						class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
 						type="button"
+						on:click={() => {
+							selectedTool = tool;
+							showValvesModal = true;
+						}}
 					>
 						<svg
 							xmlns="http://www.w3.org/2000/svg"
@@ -336,6 +344,8 @@
 	</a>
 </div>
 
+<ValvesModal bind:show={showValvesModal} type="tool" id={selectedTool?.id ?? null} />
+
 <ConfirmDialog
 	bind:show={showConfirm}
 	on:confirm={() => {

+ 119 - 76
src/lib/components/workspace/ValvesModal.svelte

@@ -5,6 +5,13 @@
 	import { addUser } from '$lib/apis/auths';
 
 	import Modal from '../common/Modal.svelte';
+	import {
+		getFunctionValvesById,
+		getFunctionValvesSpecById,
+		updateFunctionValvesById
+	} from '$lib/apis/functions';
+	import { getToolValvesById, getToolValvesSpecById, updateToolValvesById } from '$lib/apis/tools';
+	import Spinner from '../common/Spinner.svelte';
 
 	const i18n = getContext('i18n');
 	const dispatch = createEventDispatcher();
@@ -14,21 +21,57 @@
 	export let type = 'tool';
 	export let id = null;
 
+	let saving = false;
 	let loading = false;
 
-	let _user = {
-		name: '',
-		email: '',
-		password: '',
-		role: 'user'
-	};
+	let valvesSpec = null;
+	let valves = {};
 
 	const submitHandler = async () => {
-		const stopLoading = () => {
-			dispatch('save');
-			loading = false;
-		};
+		saving = true;
+
+		let res = null;
+
+		if (type === 'tool') {
+			res = await updateToolValvesById(localStorage.token, id, valves).catch((error) => {
+				toast.error(error);
+			});
+		} else if (type === 'function') {
+			res = await updateFunctionValvesById(localStorage.token, id, valves).catch((error) => {
+				toast.error(error);
+			});
+		}
+
+		if (res) {
+			toast.success('Valves updated successfully');
+		}
+
+		saving = false;
 	};
+
+	const initHandler = async () => {
+		loading = true;
+		valves = {};
+		valvesSpec = null;
+
+		if (type === 'tool') {
+			valves = await getToolValvesById(localStorage.token, id);
+			valvesSpec = await getToolValvesSpecById(localStorage.token, id);
+		} else if (type === 'function') {
+			valves = await getFunctionValvesById(localStorage.token, id);
+			valvesSpec = await getFunctionValvesSpecById(localStorage.token, id);
+		}
+
+		if (!valves) {
+			valves = {};
+		}
+
+		loading = false;
+	};
+
+	$: if (show) {
+		initHandler();
+	}
 </script>
 
 <Modal size="sm" bind:show>
@@ -63,81 +106,81 @@
 					}}
 				>
 					<div class="px-1">
-						<div class="flex flex-col w-full">
-							<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Role')}</div>
-
-							<div class="flex-1">
-								<select
-									class="w-full capitalize rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
-									bind:value={_user.role}
-									placeholder={$i18n.t('Enter Your Role')}
-									required
-								>
-									<option value="pending"> {$i18n.t('pending')} </option>
-									<option value="user"> {$i18n.t('user')} </option>
-									<option value="admin"> {$i18n.t('admin')} </option>
-								</select>
-							</div>
-						</div>
-
-						<div class="flex flex-col w-full mt-2">
-							<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Name')}</div>
-
-							<div class="flex-1">
-								<input
-									class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
-									type="text"
-									bind:value={_user.name}
-									placeholder={$i18n.t('Enter Your Full Name')}
-									autocomplete="off"
-									required
-								/>
-							</div>
-						</div>
-
-						<hr class=" dark:border-gray-800 my-3 w-full" />
-
-						<div class="flex flex-col w-full">
-							<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Email')}</div>
-
-							<div class="flex-1">
-								<input
-									class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
-									type="email"
-									bind:value={_user.email}
-									placeholder={$i18n.t('Enter Your Email')}
-									autocomplete="off"
-									required
-								/>
-							</div>
-						</div>
-
-						<div class="flex flex-col w-full mt-2">
-							<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Password')}</div>
-
-							<div class="flex-1">
-								<input
-									class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
-									type="password"
-									bind:value={_user.password}
-									placeholder={$i18n.t('Enter Your Password')}
-									autocomplete="off"
-								/>
-							</div>
-						</div>
+						{#if !loading}
+							{#if valvesSpec}
+								{#each Object.keys(valvesSpec.properties) as property, idx}
+									<div class=" py-0.5 w-full justify-between">
+										<div class="flex w-full justify-between">
+											<div class=" self-center text-xs font-medium">
+												{valvesSpec.properties[property].title}
+
+												{#if (valvesSpec?.required ?? []).includes(property)}
+													<span class=" text-gray-500">*required</span>
+												{/if}
+											</div>
+
+											<button
+												class="p-1 px-3 text-xs flex rounded transition"
+												type="button"
+												on:click={() => {
+													valves[property] = (valves[property] ?? null) === null ? '' : null;
+												}}
+											>
+												{#if (valves[property] ?? null) === null}
+													<span class="ml-2 self-center">
+														{#if (valvesSpec?.required ?? []).includes(property)}
+															{$i18n.t('None')}
+														{:else}
+															{$i18n.t('Default')}
+														{/if}
+													</span>
+												{:else}
+													<span class="ml-2 self-center"> {$i18n.t('Custom')} </span>
+												{/if}
+											</button>
+										</div>
+
+										{#if (valves[property] ?? null) !== null}
+											<div class="flex mt-0.5 mb-1.5 space-x-2">
+												<div class=" flex-1">
+													<input
+														class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+														type="text"
+														placeholder={valvesSpec.properties[property].title}
+														bind:value={valves[property]}
+														autocomplete="off"
+														required={(valvesSpec?.required ?? []).includes(property)}
+													/>
+												</div>
+											</div>
+										{/if}
+
+										{#if (valvesSpec.properties[property]?.description ?? null) !== null}
+											<div class="text-xs text-gray-500">
+												{valvesSpec.properties[property].description}
+											</div>
+										{/if}
+									</div>
+								{/each}
+							{:else}
+								<div>No valves</div>
+							{/if}
+						{:else}
+							<Spinner className="size-5" />
+						{/if}
 					</div>
 
 					<div class="flex justify-end pt-3 text-sm font-medium">
 						<button
-							class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {loading
+							class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {saving
 								? ' cursor-not-allowed'
 								: ''}"
 							type="submit"
-							disabled={loading}
+							disabled={saving}
 						>
-							{$i18n.t('Submit')}
+							{$i18n.t('Save')}
 
-							{#if loading}
+							{#if saving}
 								<div class="ml-2 self-center">
 									<svg
 										class=" w-4 h-4"