Browse Source

feat: user valves integration

Timothy J. Baek 10 tháng trước cách đây
mục cha
commit
d362fd027e

+ 16 - 12
backend/apps/webui/models/functions.py

@@ -8,6 +8,8 @@ from apps.webui.internal.db import DB, JSONField
 from apps.webui.models.users import Users
 
 import json
+import copy
+
 
 from config import SRC_LOG_LEVELS
 
@@ -121,14 +123,15 @@ class FunctionsTable:
     ) -> Optional[dict]:
         try:
             user = Users.get_user_by_id(user_id)
+            user_settings = user.settings.model_dump()
 
             # Check if user has "functions" and "valves" settings
-            if "functions" not in user.settings:
-                user.settings["functions"] = {}
-            if "valves" not in user.settings["functions"]:
-                user.settings["functions"]["valves"] = {}
+            if "functions" not in user_settings:
+                user_settings["functions"] = {}
+            if "valves" not in user_settings["functions"]:
+                user_settings["functions"]["valves"] = {}
 
-            return user.settings["functions"]["valves"].get(id, {})
+            return user_settings["functions"]["valves"].get(id, {})
         except Exception as e:
             print(f"An error occurred: {e}")
             return None
@@ -138,20 +141,21 @@ class FunctionsTable:
     ) -> Optional[dict]:
         try:
             user = Users.get_user_by_id(user_id)
+            user_settings = user.settings.model_dump()
 
             # Check if user has "functions" and "valves" settings
-            if "functions" not in user.settings:
-                user.settings["functions"] = {}
-            if "valves" not in user.settings["functions"]:
-                user.settings["functions"]["valves"] = {}
+            if "functions" not in user_settings:
+                user_settings["functions"] = {}
+            if "valves" not in user_settings["functions"]:
+                user_settings["functions"]["valves"] = {}
 
-            user.settings["functions"]["valves"][id] = valves
+            user_settings["functions"]["valves"][id] = valves
 
             # Update the user settings in the database
-            query = Users.update_user_by_id(user_id, {"settings": user.settings})
+            query = Users.update_user_by_id(user_id, {"settings": user_settings})
             query.execute()
 
-            return user.settings["functions"]["valves"][id]
+            return user_settings["functions"]["valves"][id]
         except Exception as e:
             print(f"An error occurred: {e}")
             return None

+ 16 - 12
backend/apps/webui/models/tools.py

@@ -8,6 +8,8 @@ from apps.webui.internal.db import DB, JSONField
 from apps.webui.models.users import Users
 
 import json
+import copy
+
 
 from config import SRC_LOG_LEVELS
 
@@ -112,14 +114,15 @@ class ToolsTable:
     ) -> Optional[dict]:
         try:
             user = Users.get_user_by_id(user_id)
+            user_settings = user.settings.model_dump()
 
             # Check if user has "tools" and "valves" settings
-            if "tools" not in user.settings:
-                user.settings["tools"] = {}
-            if "valves" not in user.settings["tools"]:
-                user.settings["tools"]["valves"] = {}
+            if "tools" not in user_settings:
+                user_settings["tools"] = {}
+            if "valves" not in user_settings["tools"]:
+                user_settings["tools"]["valves"] = {}
 
-            return user.settings["tools"]["valves"].get(id, {})
+            return user_settings["tools"]["valves"].get(id, {})
         except Exception as e:
             print(f"An error occurred: {e}")
             return None
@@ -129,20 +132,21 @@ class ToolsTable:
     ) -> Optional[dict]:
         try:
             user = Users.get_user_by_id(user_id)
+            user_settings = user.settings.model_dump()
 
             # Check if user has "tools" and "valves" settings
-            if "tools" not in user.settings:
-                user.settings["tools"] = {}
-            if "valves" not in user.settings["tools"]:
-                user.settings["tools"]["valves"] = {}
+            if "tools" not in user_settings:
+                user_settings["tools"] = {}
+            if "valves" not in user_settings["tools"]:
+                user_settings["tools"]["valves"] = {}
 
-            user.settings["tools"]["valves"][id] = valves
+            user_settings["tools"]["valves"][id] = valves
 
             # Update the user settings in the database
-            query = Users.update_user_by_id(user_id, {"settings": user.settings})
+            query = Users.update_user_by_id(user_id, {"settings": user_settings})
             query.execute()
 
-            return user.settings["tools"]["valves"][id]
+            return user_settings["tools"]["valves"][id]
         except Exception as e:
             print(f"An error occurred: {e}")
             return None

+ 3 - 4
backend/apps/webui/routers/functions.py

@@ -145,7 +145,7 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user
 async def get_function_user_valves_spec_by_id(
     request: Request, id: str, user=Depends(get_verified_user)
 ):
-    function = Functions.get_tool_by_id(id)
+    function = Functions.get_function_by_id(id)
     if function:
         if id in request.app.state.FUNCTIONS:
             function_module = request.app.state.FUNCTIONS[id]
@@ -168,8 +168,7 @@ async def get_function_user_valves_spec_by_id(
 async def update_function_user_valves_by_id(
     request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
 ):
-
-    function = Functions.get_tool_by_id(id)
+    function = Functions.get_function_by_id(id)
 
     if function:
         if id in request.app.state.FUNCTIONS:
@@ -211,7 +210,7 @@ async def update_function_user_valves_by_id(
 
 
 @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
-async def update_toolkit_by_id(
+async def update_function_by_id(
     request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
 ):
     function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")

+ 99 - 0
src/lib/apis/functions/index.ts

@@ -191,3 +191,102 @@ export const deleteFunctionById = async (token: string, id: string) => {
 
 	return res;
 };
+
+export const getUserValvesById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getUserValvesSpecById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user/spec`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateUserValvesById = async (token: string, id: string, valves: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...valves
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 99 - 0
src/lib/apis/tools/index.ts

@@ -191,3 +191,102 @@ export const deleteToolById = async (token: string, id: string) => {
 
 	return res;
 };
+
+export const getUserValvesById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getUserValvesSpecById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user/spec`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateUserValvesById = async (token: string, id: string, valves: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...valves
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 142 - 9
src/lib/components/chat/Settings/Valves.svelte

@@ -1,12 +1,24 @@
 <script lang="ts">
-	import { getBackendConfig } from '$lib/apis';
-	import { setDefaultPromptSuggestions } from '$lib/apis/configs';
-	import Switch from '$lib/components/common/Switch.svelte';
+	import { toast } from 'svelte-sonner';
+
 	import { config, functions, models, settings, tools, user } from '$lib/stores';
 	import { createEventDispatcher, onMount, getContext, tick } from 'svelte';
-	import { toast } from 'svelte-sonner';
+
+	import {
+		getUserValvesSpecById as getToolUserValvesSpecById,
+		getUserValvesById as getToolUserValvesById,
+		updateUserValvesById as updateToolUserValvesById
+	} from '$lib/apis/tools';
+	import {
+		getUserValvesSpecById as getFunctionUserValvesSpecById,
+		getUserValvesById as getFunctionUserValvesById,
+		updateUserValvesById as updateFunctionUserValvesById
+	} from '$lib/apis/functions';
+
 	import ManageModal from './Personalization/ManageModal.svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
+	import Spinner from '$lib/components/common/Spinner.svelte';
+
 	const dispatch = createEventDispatcher();
 
 	const i18n = getContext('i18n');
@@ -16,15 +28,85 @@
 	let tab = 'tools';
 	let selectedId = '';
 
+	let loading = false;
+
+	let valvesSpec = null;
+	let valves = {};
+
+	const getUserValves = async () => {
+		loading = true;
+		if (tab === 'tools') {
+			valves = await getToolUserValvesById(localStorage.token, selectedId);
+			valvesSpec = await getToolUserValvesSpecById(localStorage.token, selectedId);
+		} else if (tab === 'functions') {
+			valves = await getFunctionUserValvesById(localStorage.token, selectedId);
+			valvesSpec = await getFunctionUserValvesSpecById(localStorage.token, selectedId);
+		}
+
+		if (valvesSpec) {
+			// Convert array to string
+			for (const property in valvesSpec.properties) {
+				if (valvesSpec.properties[property]?.type === 'array') {
+					valves[property] = (valves[property] ?? []).join(',');
+				}
+			}
+		}
+
+		loading = false;
+	};
+
+	const submitHandler = async () => {
+		if (valvesSpec) {
+			// Convert string to array
+			for (const property in valvesSpec.properties) {
+				if (valvesSpec.properties[property]?.type === 'array') {
+					valves[property] = (valves[property] ?? '').split(',').map((v) => v.trim());
+				}
+			}
+
+			if (tab === 'tools') {
+				const res = await updateToolUserValvesById(localStorage.token, selectedId, valves).catch(
+					(error) => {
+						toast.error(error);
+						return null;
+					}
+				);
+
+				if (res) {
+					toast.success('Valves updated');
+					valves = res;
+				}
+			} else if (tab === 'functions') {
+				const res = await updateFunctionUserValvesById(
+					localStorage.token,
+					selectedId,
+					valves
+				).catch((error) => {
+					toast.error(error);
+					return null;
+				});
+
+				if (res) {
+					toast.success('Valves updated');
+					valves = res;
+				}
+			}
+		}
+	};
+
 	$: if (tab) {
 		selectedId = '';
 	}
-	onMount(async () => {});
+
+	$: if (selectedId) {
+		getUserValves();
+	}
 </script>
 
 <form
 	class="flex flex-col h-full justify-between space-y-3 text-sm"
 	on:submit|preventDefault={() => {
+		submitHandler();
 		dispatch('save');
 	}}
 >
@@ -82,11 +164,62 @@
 			</div>
 		</div>
 
-		<hr class="dark:border-gray-800 my-3 w-full" />
+		{#if selectedId}
+			<hr class="dark:border-gray-800 my-3 w-full" />
 
-		<div>
-			<div class="flex items-center justify-between mb-1" />
-		</div>
+			<div>
+				{#if !loading}
+					{#if valvesSpec}
+						{#each Object.keys(valvesSpec.properties) as property, idx}
+							<div class=" py-0.5 w-full justify-between">
+								<div class="flex w-full justify-between">
+									<div class=" self-center text-xs font-medium">
+										{valvesSpec.properties[property].title}
+
+										{#if (valvesSpec?.required ?? []).includes(property)}
+											<span class=" text-gray-500">*required</span>
+										{/if}
+									</div>
+
+									<button
+										class="p-1 px-3 text-xs flex rounded transition"
+										type="button"
+										on:click={() => {
+											valves[property] = (valves[property] ?? null) === null ? '' : null;
+										}}
+									>
+										{#if (valves[property] ?? null) === null}
+											<span class="ml-2 self-center"> {$i18n.t('None')} </span>
+										{:else}
+											<span class="ml-2 self-center"> {$i18n.t('Custom')} </span>
+										{/if}
+									</button>
+								</div>
+
+								{#if (valves[property] ?? null) !== null}
+									<div class="flex mt-0.5 space-x-2">
+										<div class=" flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												type="text"
+												placeholder={valvesSpec.properties[property].title}
+												bind:value={valves[property]}
+												autocomplete="off"
+												required={(valvesSpec?.required ?? []).includes(property)}
+											/>
+										</div>
+									</div>
+								{/if}
+							</div>
+						{/each}
+					{:else}
+						<div>No valves</div>
+					{/if}
+				{:else}
+					<Spinner className="size-5" />
+				{/if}
+			</div>
+		{/if}
 	</div>
 
 	<div class="flex justify-end text-sm font-medium">