Browse Source

feat: user valves endpoints

Timothy J. Baek 10 months ago
parent
commit
15fc23df87

+ 41 - 0
backend/apps/webui/models/functions.py

@@ -5,6 +5,7 @@ from typing import List, Union, Optional
 import time
 import logging
 from apps.webui.internal.db import DB, JSONField
+from apps.webui.models.users import Users
 
 import json
 
@@ -115,6 +116,46 @@ class FunctionsTable:
             for function in Function.select().where(Function.type == type)
         ]
 
+    def get_user_valves_by_id_and_user_id(
+        self, id: str, user_id: str
+    ) -> Optional[dict]:
+        try:
+            user = Users.get_user_by_id(user_id)
+
+            # Check if user has "functions" and "valves" settings
+            if "functions" not in user.settings:
+                user.settings["functions"] = {}
+            if "valves" not in user.settings["functions"]:
+                user.settings["functions"]["valves"] = {}
+
+            return user.settings["functions"]["valves"].get(id, {})
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            return None
+
+    def update_user_valves_by_id_and_user_id(
+        self, id: str, user_id: str, valves: dict
+    ) -> Optional[dict]:
+        try:
+            user = Users.get_user_by_id(user_id)
+
+            # Check if user has "functions" and "valves" settings
+            if "functions" not in user.settings:
+                user.settings["functions"] = {}
+            if "valves" not in user.settings["functions"]:
+                user.settings["functions"]["valves"] = {}
+
+            user.settings["functions"]["valves"][id] = valves
+
+            # Update the user settings in the database
+            query = Users.update_user_by_id(user_id, {"settings": user.settings})
+            query.execute()
+
+            return user.settings["functions"]["valves"][id]
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            return None
+
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
         try:
             query = Function.update(

+ 41 - 0
backend/apps/webui/models/tools.py

@@ -5,6 +5,7 @@ from typing import List, Union, Optional
 import time
 import logging
 from apps.webui.internal.db import DB, JSONField
+from apps.webui.models.users import Users
 
 import json
 
@@ -106,6 +107,46 @@ class ToolsTable:
     def get_tools(self) -> List[ToolModel]:
         return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
 
+    def get_user_valves_by_id_and_user_id(
+        self, id: str, user_id: str
+    ) -> Optional[dict]:
+        try:
+            user = Users.get_user_by_id(user_id)
+
+            # Check if user has "tools" and "valves" settings
+            if "tools" not in user.settings:
+                user.settings["tools"] = {}
+            if "valves" not in user.settings["tools"]:
+                user.settings["tools"]["valves"] = {}
+
+            return user.settings["tools"]["valves"].get(id, {})
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            return None
+
+    def update_user_valves_by_id_and_user_id(
+        self, id: str, user_id: str, valves: dict
+    ) -> Optional[dict]:
+        try:
+            user = Users.get_user_by_id(user_id)
+
+            # Check if user has "tools" and "valves" settings
+            if "tools" not in user.settings:
+                user.settings["tools"] = {}
+            if "valves" not in user.settings["tools"]:
+                user.settings["tools"]["valves"] = {}
+
+            user.settings["tools"]["valves"][id] = valves
+
+            # Update the user settings in the database
+            query = Users.update_user_by_id(user_id, {"settings": user.settings})
+            query.execute()
+
+            return user.settings["tools"]["valves"][id]
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            return None
+
     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
         try:
             query = Tool.update(

+ 88 - 0
backend/apps/webui/routers/functions.py

@@ -117,6 +117,94 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
         )
 
 
+############################
+# FunctionUserValves
+############################
+
+
+@router.get("/id/{id}/valves/user", response_model=Optional[dict])
+async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
+    function = Functions.get_function_by_id(id)
+    if function:
+        try:
+            user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
+            return user_valves
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
+async def get_function_user_valves_spec_by_id(
+    request: Request, id: str, user=Depends(get_verified_user)
+):
+    function = Functions.get_tool_by_id(id)
+    if function:
+        if id in request.app.state.FUNCTIONS:
+            function_module = request.app.state.FUNCTIONS[id]
+        else:
+            function_module, function_type = load_function_module_by_id(id)
+            request.app.state.FUNCTIONS[id] = function_module
+
+        if hasattr(function_module, "UserValves"):
+            UserValves = function_module.UserValves
+            return UserValves.schema()
+        return None
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
+async def update_function_user_valves_by_id(
+    request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
+):
+
+    function = Functions.get_tool_by_id(id)
+
+    if function:
+        if id in request.app.state.FUNCTIONS:
+            function_module = request.app.state.FUNCTIONS[id]
+        else:
+            function_module, function_type = load_function_module_by_id(id)
+            request.app.state.FUNCTIONS[id] = function_module
+
+        if hasattr(function_module, "UserValves"):
+            UserValves = function_module.UserValves
+
+            try:
+                user_valves = UserValves(**form_data)
+                Functions.update_user_valves_by_id_and_user_id(
+                    id, user.id, user_valves.model_dump()
+                )
+                return user_valves.model_dump()
+            except Exception as e:
+                print(e)
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.NOT_FOUND,
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 # UpdateFunctionById
 ############################

+ 91 - 2
backend/apps/webui/routers/tools.py

@@ -6,10 +6,12 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 import json
 
+
+from apps.webui.models.users import Users
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
 from apps.webui.utils import load_toolkit_module_by_id
 
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_admin_user, get_verified_user
 from utils.tools import get_tools_specs
 from constants import ERROR_MESSAGES
 
@@ -32,7 +34,7 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[ToolResponse])
-async def get_toolkits(user=Depends(get_current_user)):
+async def get_toolkits(user=Depends(get_verified_user)):
     toolkits = [toolkit for toolkit in Tools.get_tools()]
     return toolkits
 
@@ -121,6 +123,93 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
         )
 
 
+############################
+# ToolUserValves
+############################
+
+
+@router.get("/id/{id}/valves/user", response_model=Optional[dict])
+async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
+    toolkit = Tools.get_tool_by_id(id)
+    if toolkit:
+        try:
+            user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
+            return user_valves
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
+async def get_toolkit_user_valves_spec_by_id(
+    request: Request, id: str, user=Depends(get_verified_user)
+):
+    toolkit = Tools.get_tool_by_id(id)
+    if toolkit:
+        if id in request.app.state.TOOLS:
+            toolkit_module = request.app.state.TOOLS[id]
+        else:
+            toolkit_module = load_toolkit_module_by_id(id)
+            request.app.state.TOOLS[id] = toolkit_module
+
+        if hasattr(toolkit_module, "UserValves"):
+            UserValves = toolkit_module.UserValves
+            return UserValves.schema()
+        return None
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
+async def update_toolkit_user_valves_by_id(
+    request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
+):
+    toolkit = Tools.get_tool_by_id(id)
+
+    if toolkit:
+        if id in request.app.state.TOOLS:
+            toolkit_module = request.app.state.TOOLS[id]
+        else:
+            toolkit_module = load_toolkit_module_by_id(id)
+            request.app.state.TOOLS[id] = toolkit_module
+
+        if hasattr(toolkit_module, "UserValves"):
+            UserValves = toolkit_module.UserValves
+
+            try:
+                user_valves = UserValves(**form_data)
+                Tools.update_user_valves_by_id_and_user_id(
+                    id, user.id, user_valves.model_dump()
+                )
+                return user_valves.model_dump()
+            except Exception as e:
+                print(e)
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.NOT_FOUND,
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 # UpdateToolkitById
 ############################

+ 37 - 22
src/lib/components/chat/Settings/Valves.svelte

@@ -29,25 +29,28 @@
 	}}
 >
 	<div class="flex flex-col pr-1.5 overflow-y-scroll max-h-[25rem]">
-		<div class="flex text-center text-sm font-medium rounded-xl bg-transparent/10 p-1 mb-2">
-			<button
-				class="w-full rounded-lg p-1.5 {tab === 'tools' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
-				type="button"
-				on:click={() => {
-					tab = 'tools';
-				}}>{$i18n.t('Tools')}</button
-			>
+		<div>
+			<div class="flex items-center justify-between mb-2">
+				<Tooltip content="">
+					<div class="text-sm font-medium">
+						{$i18n.t('Manage Valves')}
+					</div>
+				</Tooltip>
 
-			<button
-				class="w-full rounded-lg p-1 {tab === 'functions' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
-				type="button"
-				on:click={() => {
-					tab = 'functions';
-				}}>{$i18n.t('Functions')}</button
-			>
+				<div class=" self-end">
+					<select
+						class=" dark:bg-gray-900 w-fit pr-8 rounded text-xs bg-transparent outline-none text-right"
+						bind:value={tab}
+						placeholder="Select"
+					>
+						<option value="tools">{$i18n.t('Tools')}</option>
+						<option value="functions">{$i18n.t('Functions')}</option>
+					</select>
+				</div>
+			</div>
 		</div>
 
-		<div class="space-y-1 px-1">
+		<div class="space-y-1">
 			<div class="flex gap-2">
 				<div class="flex-1">
 					<select
@@ -57,18 +60,30 @@
 							await tick();
 						}}
 					>
-						<option value="" selected disabled class="bg-gray-100 dark:bg-gray-700"
-							>{$i18n.t('Select a tool/function')}</option
-						>
+						{#if tab === 'tools'}
+							<option value="" selected disabled class="bg-gray-100 dark:bg-gray-700"
+								>{$i18n.t('Select a tool')}</option
+							>
 
-						{#each $tools as tool, toolIdx}
-							<option value={tool.id} class="bg-gray-100 dark:bg-gray-700">{tool.name}</option>
-						{/each}
+							{#each $tools as tool, toolIdx}
+								<option value={tool.id} class="bg-gray-100 dark:bg-gray-700">{tool.name}</option>
+							{/each}
+						{:else if tab === 'functions'}
+							<option value="" selected disabled class="bg-gray-100 dark:bg-gray-700"
+								>{$i18n.t('Select a function')}</option
+							>
+
+							{#each $functions as func, funcIdx}
+								<option value={func.id} class="bg-gray-100 dark:bg-700">{func.name}</option>
+							{/each}
+						{/if}
 					</select>
 				</div>
 			</div>
 		</div>
 
+		<hr class="dark:border-gray-800 my-3 w-full" />
+
 		<div>
 			<div class="flex items-center justify-between mb-1" />
 		</div>

+ 6 - 1
src/routes/(app)/+layout.svelte

@@ -29,13 +29,15 @@
 		showChangelog,
 		config,
 		showCallOverlay,
-		tools
+		tools,
+		functions
 	} from '$lib/stores';
 
 	import SettingsModal from '$lib/components/chat/SettingsModal.svelte';
 	import Sidebar from '$lib/components/layout/Sidebar.svelte';
 	import ChangelogModal from '$lib/components/ChangelogModal.svelte';
 	import AccountPending from '$lib/components/layout/Overlay/AccountPending.svelte';
+	import { getFunctions } from '$lib/apis/functions';
 
 	const i18n = getContext('i18n');
 
@@ -93,6 +95,9 @@
 				(async () => {
 					tools.set(await getTools(localStorage.token));
 				})(),
+				(async () => {
+					functions.set(await getFunctions(localStorage.token));
+				})(),
 				(async () => {
 					banners.set(await getBanners(localStorage.token));
 				})(),