Pārlūkot izejas kodu

Merge pull request #3400 from open-webui/valves

feat: tools & functions valves
Timothy Jaeryang Baek 10 mēneši atpakaļ
vecāks
revīzija
65dbf9ba3f

+ 50 - 0
backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py

@@ -0,0 +1,50 @@
+"""Peewee migrations -- 009_add_models.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    migrator.add_fields("tool", valves=pw.TextField(null=True))
+    migrator.add_fields("function", valves=pw.TextField(null=True))
+    migrator.add_fields("function", is_active=pw.BooleanField(default=False))
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    migrator.remove_fields("tool", "valves")
+    migrator.remove_fields("function", "valves")
+    migrator.remove_fields("function", "is_active")

+ 1 - 1
backend/apps/webui/main.py

@@ -103,7 +103,7 @@ async def get_status():
 
 
 async def get_pipe_models():
-    pipes = Functions.get_functions_by_type("pipe")
+    pipes = Functions.get_functions_by_type("pipe", active_only=True)
     pipe_models = []
 
     for pipe in pipes:

+ 57 - 9
backend/apps/webui/models/functions.py

@@ -28,6 +28,8 @@ class Function(Model):
     type = TextField()
     content = TextField()
     meta = JSONField()
+    valves = JSONField()
+    is_active = BooleanField(default=False)
     updated_at = BigIntegerField()
     created_at = BigIntegerField()
 
@@ -46,6 +48,7 @@ class FunctionModel(BaseModel):
     type: str
     content: str
     meta: FunctionMeta
+    is_active: bool = False
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
@@ -61,6 +64,7 @@ class FunctionResponse(BaseModel):
     type: str
     name: str
     meta: FunctionMeta
+    is_active: bool
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
@@ -72,6 +76,10 @@ class FunctionForm(BaseModel):
     meta: FunctionMeta
 
 
+class FunctionValves(BaseModel):
+    valves: Optional[dict] = None
+
+
 class FunctionsTable:
     def __init__(self, db):
         self.db = db
@@ -107,16 +115,56 @@ class FunctionsTable:
         except:
             return None
 
-    def get_functions(self) -> List[FunctionModel]:
-        return [
-            FunctionModel(**model_to_dict(function)) for function in Function.select()
-        ]
+    def get_functions(self, active_only=False) -> List[FunctionModel]:
+        if active_only:
+            return [
+                FunctionModel(**model_to_dict(function))
+                for function in Function.select().where(Function.is_active == True)
+            ]
+        else:
+            return [
+                FunctionModel(**model_to_dict(function))
+                for function in Function.select()
+            ]
+
+    def get_functions_by_type(
+        self, type: str, active_only=False
+    ) -> List[FunctionModel]:
+        if active_only:
+            return [
+                FunctionModel(**model_to_dict(function))
+                for function in Function.select().where(
+                    Function.type == type, Function.is_active == True
+                )
+            ]
+        else:
+            return [
+                FunctionModel(**model_to_dict(function))
+                for function in Function.select().where(Function.type == type)
+            ]
+
+    def get_function_valves_by_id(self, id: str) -> Optional[FunctionValves]:
+        try:
+            function = Function.get(Function.id == id)
+            return FunctionValves(**model_to_dict(function))
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            return None
+
+    def update_function_valves_by_id(
+        self, id: str, valves: dict
+    ) -> Optional[FunctionValves]:
+        try:
+            query = Function.update(
+                **{"valves": valves},
+                updated_at=int(time.time()),
+            ).where(Function.id == id)
+            query.execute()
 
-    def get_functions_by_type(self, type: str) -> List[FunctionModel]:
-        return [
-            FunctionModel(**model_to_dict(function))
-            for function in Function.select().where(Function.type == type)
-        ]
+            function = Function.get(Function.id == id)
+            return FunctionValves(**model_to_dict(function))
+        except:
+            return None
 
     def get_user_valves_by_id_and_user_id(
         self, id: str, user_id: str

+ 26 - 0
backend/apps/webui/models/tools.py

@@ -28,6 +28,7 @@ class Tool(Model):
     content = TextField()
     specs = JSONField()
     meta = JSONField()
+    valves = JSONField()
     updated_at = BigIntegerField()
     created_at = BigIntegerField()
 
@@ -71,6 +72,10 @@ class ToolForm(BaseModel):
     meta: ToolMeta
 
 
+class ToolValves(BaseModel):
+    valves: Optional[dict] = None
+
+
 class ToolsTable:
     def __init__(self, db):
         self.db = db
@@ -109,6 +114,27 @@ class ToolsTable:
     def get_tools(self) -> List[ToolModel]:
         return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
 
+    def get_tool_valves_by_id(self, id: str) -> Optional[ToolValves]:
+        try:
+            tool = Tool.get(Tool.id == id)
+            return ToolValves(**model_to_dict(tool))
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            return None
+
+    def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
+        try:
+            query = Tool.update(
+                **{"valves": valves},
+                updated_at=int(time.time()),
+            ).where(Tool.id == id)
+            query.execute()
+
+            tool = Tool.get(Tool.id == id)
+            return ToolValves(**model_to_dict(tool))
+        except:
+            return None
+
     def get_user_valves_by_id_and_user_id(
         self, id: str, user_id: str
     ) -> Optional[dict]:

+ 125 - 0
backend/apps/webui/routers/functions.py

@@ -117,6 +117,103 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
         )
 
 
+############################
+# GetFunctionValves
+############################
+
+
+@router.get("/id/{id}/valves", response_model=Optional[dict])
+async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
+    function = Functions.get_function_by_id(id)
+    if function:
+        try:
+            function_valves = Functions.get_function_valves_by_id(id)
+            return function_valves.valves
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# GetFunctionValvesSpec
+############################
+
+
+@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
+async def get_function_valves_spec_by_id(
+    request: Request, id: str, user=Depends(get_admin_user)
+):
+    function = Functions.get_function_by_id(id)
+    if function:
+        if id in request.app.state.FUNCTIONS:
+            function_module = request.app.state.FUNCTIONS[id]
+        else:
+            function_module, function_type = load_function_module_by_id(id)
+            request.app.state.FUNCTIONS[id] = function_module
+
+        if hasattr(function_module, "Valves"):
+            Valves = function_module.Valves
+            return Valves.schema()
+        return None
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# UpdateFunctionValves
+############################
+
+
+@router.post("/id/{id}/valves/update", response_model=Optional[dict])
+async def update_function_valves_by_id(
+    request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
+):
+    function = Functions.get_function_by_id(id)
+    if function:
+
+        if id in request.app.state.FUNCTIONS:
+            function_module = request.app.state.FUNCTIONS[id]
+        else:
+            function_module, function_type = load_function_module_by_id(id)
+            request.app.state.FUNCTIONS[id] = function_module
+
+        if hasattr(function_module, "Valves"):
+            Valves = function_module.Valves
+
+            try:
+                form_data = {k: v for k, v in form_data.items() if v is not None}
+                valves = Valves(**form_data)
+                Functions.update_function_valves_by_id(id, valves.model_dump())
+                return valves.model_dump()
+            except Exception as e:
+                print(e)
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.NOT_FOUND,
+            )
+
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 # FunctionUserValves
 ############################
@@ -181,6 +278,7 @@ async def update_function_user_valves_by_id(
             UserValves = function_module.UserValves
 
             try:
+                form_data = {k: v for k, v in form_data.items() if v is not None}
                 user_valves = UserValves(**form_data)
                 Functions.update_user_valves_by_id_and_user_id(
                     id, user.id, user_valves.model_dump()
@@ -204,6 +302,33 @@ async def update_function_user_valves_by_id(
         )
 
 
+############################
+# ToggleFunctionById
+############################
+
+
+@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
+async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
+    function = Functions.get_function_by_id(id)
+    if function:
+        function = Functions.update_function_by_id(
+            id, {"is_active": not function.is_active}
+        )
+
+        if function:
+            return function
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 # UpdateFunctionById
 ############################

+ 97 - 0
backend/apps/webui/routers/tools.py

@@ -123,6 +123,102 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
         )
 
 
+############################
+# GetToolValves
+############################
+
+
+@router.get("/id/{id}/valves", response_model=Optional[dict])
+async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
+    toolkit = Tools.get_tool_by_id(id)
+    if toolkit:
+        try:
+            tool_valves = Tools.get_tool_valves_by_id(id)
+            return tool_valves.valves
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# GetToolValvesSpec
+############################
+
+
+@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
+async def get_toolkit_valves_spec_by_id(
+    request: Request, id: str, user=Depends(get_admin_user)
+):
+    toolkit = Tools.get_tool_by_id(id)
+    if toolkit:
+        if id in request.app.state.TOOLS:
+            toolkit_module = request.app.state.TOOLS[id]
+        else:
+            toolkit_module = load_toolkit_module_by_id(id)
+            request.app.state.TOOLS[id] = toolkit_module
+
+        if hasattr(toolkit_module, "UserValves"):
+            UserValves = toolkit_module.UserValves
+            return UserValves.schema()
+        return None
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# UpdateToolValves
+############################
+
+
+@router.post("/id/{id}/valves/update", response_model=Optional[dict])
+async def update_toolkit_valves_by_id(
+    request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
+):
+    toolkit = Tools.get_tool_by_id(id)
+    if toolkit:
+        if id in request.app.state.TOOLS:
+            toolkit_module = request.app.state.TOOLS[id]
+        else:
+            toolkit_module = load_toolkit_module_by_id(id)
+            request.app.state.TOOLS[id] = toolkit_module
+
+        if hasattr(toolkit_module, "Valves"):
+            Valves = toolkit_module.Valves
+
+            try:
+                form_data = {k: v for k, v in form_data.items() if v is not None}
+                valves = Valves(**form_data)
+                Tools.update_tool_valves_by_id(id, valves.model_dump())
+                return valves.model_dump()
+            except Exception as e:
+                print(e)
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.NOT_FOUND,
+            )
+
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 # ToolUserValves
 ############################
@@ -187,6 +283,7 @@ async def update_toolkit_user_valves_by_id(
             UserValves = toolkit_module.UserValves
 
             try:
+                form_data = {k: v for k, v in form_data.items() if v is not None}
                 user_valves = UserValves(**form_data)
                 Tools.update_user_valves_by_id_and_user_id(
                     id, user.id, user_valves.model_dump()

+ 62 - 55
backend/main.py

@@ -376,70 +376,77 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 )
             model = app.state.MODELS[model_id]
 
+            filter_ids = [
+                function.id
+                for function in Functions.get_functions_by_type(
+                    "filter", active_only=True
+                )
+            ]
             # Check if the model has any filters
             if "info" in model and "meta" in model["info"]:
-                for filter_id in model["info"]["meta"].get("filterIds", []):
-                    filter = Functions.get_function_by_id(filter_id)
-                    if filter:
-                        if filter_id in webui_app.state.FUNCTIONS:
-                            function_module = webui_app.state.FUNCTIONS[filter_id]
-                        else:
-                            function_module, function_type = load_function_module_by_id(
-                                filter_id
-                            )
-                            webui_app.state.FUNCTIONS[filter_id] = function_module
+                filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+                filter_ids = list(set(filter_ids))
+
+            for filter_id in filter_ids:
+                filter = Functions.get_function_by_id(filter_id)
+                if filter:
+                    if filter_id in webui_app.state.FUNCTIONS:
+                        function_module = webui_app.state.FUNCTIONS[filter_id]
+                    else:
+                        function_module, function_type = load_function_module_by_id(
+                            filter_id
+                        )
+                        webui_app.state.FUNCTIONS[filter_id] = function_module
 
-                        # Check if the function has a file_handler variable
-                        if hasattr(function_module, "file_handler"):
-                            skip_files = function_module.file_handler
+                    # Check if the function has a file_handler variable
+                    if hasattr(function_module, "file_handler"):
+                        skip_files = function_module.file_handler
 
-                        try:
-                            if hasattr(function_module, "inlet"):
-                                inlet = function_module.inlet
-
-                                # Get the signature of the function
-                                sig = inspect.signature(inlet)
-                                params = {"body": data}
-
-                                if "__user__" in sig.parameters:
-                                    __user__ = {
-                                        "id": user.id,
-                                        "email": user.email,
-                                        "name": user.name,
-                                        "role": user.role,
-                                    }
-
-                                    try:
-                                        if hasattr(function_module, "UserValves"):
-                                            __user__["valves"] = (
-                                                function_module.UserValves(
-                                                    **Functions.get_user_valves_by_id_and_user_id(
-                                                        filter_id, user.id
-                                                    )
-                                                )
+                    try:
+                        if hasattr(function_module, "inlet"):
+                            inlet = function_module.inlet
+
+                            # Get the signature of the function
+                            sig = inspect.signature(inlet)
+                            params = {"body": data}
+
+                            if "__user__" in sig.parameters:
+                                __user__ = {
+                                    "id": user.id,
+                                    "email": user.email,
+                                    "name": user.name,
+                                    "role": user.role,
+                                }
+
+                                try:
+                                    if hasattr(function_module, "UserValves"):
+                                        __user__["valves"] = function_module.UserValves(
+                                            **Functions.get_user_valves_by_id_and_user_id(
+                                                filter_id, user.id
                                             )
-                                    except Exception as e:
-                                        print(e)
+                                        )
+                                except Exception as e:
+                                    print(e)
 
-                                    params = {**params, "__user__": __user__}
+                                params = {**params, "__user__": __user__}
 
-                                if "__id__" in sig.parameters:
-                                    params = {
-                                        **params,
-                                        "__id__": filter_id,
-                                    }
+                            if "__id__" in sig.parameters:
+                                params = {
+                                    **params,
+                                    "__id__": filter_id,
+                                }
 
-                                if inspect.iscoroutinefunction(inlet):
-                                    data = await inlet(**params)
-                                else:
-                                    data = inlet(**params)
+                            if inspect.iscoroutinefunction(inlet):
+                                data = await inlet(**params)
+                            else:
+                                data = inlet(**params)
 
-                        except Exception as e:
-                            print(f"Error: {e}")
-                            return JSONResponse(
-                                status_code=status.HTTP_400_BAD_REQUEST,
-                                content={"detail": str(e)},
-                            )
+                    except Exception as e:
+                        print(f"Error: {e}")
+                        return JSONResponse(
+                            status_code=status.HTTP_400_BAD_REQUEST,
+                            content={"detail": str(e)},
+                        )
 
             # Set the task model
             task_model_id = data["model"]

+ 131 - 0
src/lib/apis/functions/index.ts

@@ -192,6 +192,137 @@ export const deleteFunctionById = async (token: string, id: string) => {
 	return res;
 };
 
+export const toggleFunctionById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getFunctionValvesById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getFunctionValvesSpecById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/spec`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateFunctionValvesById = async (token: string, id: string, valves: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...valves
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getUserValvesById = async (token: string, id: string) => {
 	let error = null;
 

+ 99 - 0
src/lib/apis/tools/index.ts

@@ -192,6 +192,105 @@ export const deleteToolById = async (token: string, id: string) => {
 	return res;
 };
 
+export const getToolValvesById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getToolValvesSpecById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/spec`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateToolValvesById = async (token: string, id: string, valves: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...valves
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getUserValvesById = async (token: string, id: string) => {
 	let error = null;
 

+ 1 - 1
src/lib/components/chat/Settings/Valves.svelte

@@ -203,7 +203,7 @@
 								</div>
 
 								{#if (valves[property] ?? null) !== null}
-									<div class="flex mt-0.5 mb-1 space-x-2">
+									<div class="flex mt-0.5 mb-1.5 space-x-2">
 										<div class=" flex-1">
 											<input
 												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"

+ 19 - 2
src/lib/components/workspace/Functions.svelte

@@ -13,7 +13,8 @@
 		deleteFunctionById,
 		exportFunctions,
 		getFunctionById,
-		getFunctions
+		getFunctions,
+		toggleFunctionById
 	} from '$lib/apis/functions';
 
 	import ArrowDownTray from '../icons/ArrowDownTray.svelte';
@@ -23,6 +24,7 @@
 	import FunctionMenu from './Functions/FunctionMenu.svelte';
 	import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte';
 	import Switch from '../common/Switch.svelte';
+	import ValvesModal from './ValvesModal.svelte';
 
 	const i18n = getContext('i18n');
 
@@ -32,6 +34,9 @@
 	let showConfirm = false;
 	let query = '';
 
+	let showValvesModal = false;
+	let selectedFunction = null;
+
 	const shareHandler = async (tool) => {
 		console.log(tool);
 	};
@@ -174,6 +179,10 @@
 					<button
 						class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
 						type="button"
+						on:click={() => {
+							selectedFunction = func;
+							showValvesModal = true;
+						}}
 					>
 						<svg
 							xmlns="http://www.w3.org/2000/svg"
@@ -224,7 +233,13 @@
 				</FunctionMenu>
 
 				<div class=" self-center mx-1">
-					<Switch />
+					<Switch
+						bind:state={func.is_active}
+						on:change={async (e) => {
+							toggleFunctionById(localStorage.token, func.id);
+							models.set(await getModels(localStorage.token));
+						}}
+					/>
 				</div>
 			</div>
 		</div>
@@ -345,6 +360,8 @@
 	</a>
 </div>
 
+<ValvesModal bind:show={showValvesModal} type="function" id={selectedFunction?.id ?? null} />
+
 <ConfirmDialog
 	bind:show={showConfirm}
 	on:confirm={() => {

+ 10 - 0
src/lib/components/workspace/Tools.svelte

@@ -20,6 +20,7 @@
 	import ConfirmDialog from '../common/ConfirmDialog.svelte';
 	import ToolMenu from './Tools/ToolMenu.svelte';
 	import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte';
+	import ValvesModal from './ValvesModal.svelte';
 
 	const i18n = getContext('i18n');
 
@@ -29,6 +30,9 @@
 	let showConfirm = false;
 	let query = '';
 
+	let showValvesModal = false;
+	let selectedTool = null;
+
 	const shareHandler = async (tool) => {
 		console.log(tool);
 	};
@@ -169,6 +173,10 @@
 					<button
 						class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
 						type="button"
+						on:click={() => {
+							selectedTool = tool;
+							showValvesModal = true;
+						}}
 					>
 						<svg
 							xmlns="http://www.w3.org/2000/svg"
@@ -336,6 +344,8 @@
 	</a>
 </div>
 
+<ValvesModal bind:show={showValvesModal} type="tool" id={selectedTool?.id ?? null} />
+
 <ConfirmDialog
 	bind:show={showConfirm}
 	on:confirm={() => {

+ 110 - 209
src/lib/components/workspace/ValvesModal.svelte

@@ -5,115 +5,73 @@
 	import { addUser } from '$lib/apis/auths';
 
 	import Modal from '../common/Modal.svelte';
-	import { WEBUI_BASE_URL } from '$lib/constants';
+	import {
+		getFunctionValvesById,
+		getFunctionValvesSpecById,
+		updateFunctionValvesById
+	} from '$lib/apis/functions';
+	import { getToolValvesById, getToolValvesSpecById, updateToolValvesById } from '$lib/apis/tools';
+	import Spinner from '../common/Spinner.svelte';
 
 	const i18n = getContext('i18n');
 	const dispatch = createEventDispatcher();
 
 	export let show = false;
 
-	let loading = false;
-	let tab = '';
-	let inputFiles;
+	export let type = 'tool';
+	export let id = null;
 
-	let _user = {
-		name: '',
-		email: '',
-		password: '',
-		role: 'user'
-	};
+	let saving = false;
+	let loading = false;
 
-	$: if (show) {
-		_user = {
-			name: '',
-			email: '',
-			password: '',
-			role: 'user'
-		};
-	}
+	let valvesSpec = null;
+	let valves = {};
 
 	const submitHandler = async () => {
-		const stopLoading = () => {
-			dispatch('save');
-			loading = false;
-		};
+		saving = true;
 
-		if (tab === '') {
-			loading = true;
+		let res = null;
 
-			const res = await addUser(
-				localStorage.token,
-				_user.name,
-				_user.email,
-				_user.password,
-				_user.role
-			).catch((error) => {
+		if (type === 'tool') {
+			res = await updateToolValvesById(localStorage.token, id, valves).catch((error) => {
 				toast.error(error);
 			});
+		} else if (type === 'function') {
+			res = await updateFunctionValvesById(localStorage.token, id, valves).catch((error) => {
+				toast.error(error);
+			});
+		}
 
-			if (res) {
-				stopLoading();
-				show = false;
-			}
-		} else {
-			if (inputFiles) {
-				loading = true;
-
-				const file = inputFiles[0];
-				const reader = new FileReader();
-
-				reader.onload = async (e) => {
-					const csv = e.target.result;
-					const rows = csv.split('\n');
-
-					let userCount = 0;
-
-					for (const [idx, row] of rows.entries()) {
-						const columns = row.split(',').map((col) => col.trim());
-						console.log(idx, columns);
-
-						if (idx > 0) {
-							if (
-								columns.length === 4 &&
-								['admin', 'user', 'pending'].includes(columns[3].toLowerCase())
-							) {
-								const res = await addUser(
-									localStorage.token,
-									columns[0],
-									columns[1],
-									columns[2],
-									columns[3].toLowerCase()
-								).catch((error) => {
-									toast.error(`Row ${idx + 1}: ${error}`);
-									return null;
-								});
-
-								if (res) {
-									userCount = userCount + 1;
-								}
-							} else {
-								toast.error(`Row ${idx + 1}: invalid format.`);
-							}
-						}
-					}
-
-					toast.success(`Successfully imported ${userCount} users.`);
-					inputFiles = null;
-					const uploadInputElement = document.getElementById('upload-user-csv-input');
+		if (res) {
+			toast.success('Valves updated successfully');
+		}
 
-					if (uploadInputElement) {
-						uploadInputElement.value = null;
-					}
+		saving = false;
+	};
 
-					stopLoading();
-				};
+	const initHandler = async () => {
+		loading = true;
+		valves = {};
+		valvesSpec = null;
+
+		if (type === 'tool') {
+			valves = await getToolValvesById(localStorage.token, id);
+			valvesSpec = await getToolValvesSpecById(localStorage.token, id);
+		} else if (type === 'function') {
+			valves = await getFunctionValvesById(localStorage.token, id);
+			valvesSpec = await getFunctionValvesSpecById(localStorage.token, id);
+		}
 
-				reader.readAsText(file);
-			} else {
-				toast.error($i18n.t('File not found.'));
-			}
+		if (!valves) {
+			valves = {};
 		}
+
+		loading = false;
 	};
+
+	$: if (show) {
+		initHandler();
+	}
 </script>
 
 <Modal size="sm" bind:show>
@@ -147,139 +105,82 @@
 						submitHandler();
 					}}
 				>
-					<div class="flex text-center text-sm font-medium rounded-xl bg-transparent/10 p-1 mb-2">
-						<button
-							class="w-full rounded-lg p-1.5 {tab === '' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
-							type="button"
-							on:click={() => {
-								tab = '';
-							}}>{$i18n.t('Form')}</button
-						>
-
-						<button
-							class="w-full rounded-lg p-1 {tab === 'import' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
-							type="button"
-							on:click={() => {
-								tab = 'import';
-							}}>{$i18n.t('CSV Import')}</button
-						>
-					</div>
 					<div class="px-1">
-						{#if tab === ''}
-							<div class="flex flex-col w-full">
-								<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Role')}</div>
-
-								<div class="flex-1">
-									<select
-										class="w-full capitalize rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
-										bind:value={_user.role}
-										placeholder={$i18n.t('Enter Your Role')}
-										required
-									>
-										<option value="pending"> {$i18n.t('pending')} </option>
-										<option value="user"> {$i18n.t('user')} </option>
-										<option value="admin"> {$i18n.t('admin')} </option>
-									</select>
-								</div>
-							</div>
-
-							<div class="flex flex-col w-full mt-2">
-								<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Name')}</div>
-
-								<div class="flex-1">
-									<input
-										class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
-										type="text"
-										bind:value={_user.name}
-										placeholder={$i18n.t('Enter Your Full Name')}
-										autocomplete="off"
-										required
-									/>
-								</div>
-							</div>
-
-							<hr class=" dark:border-gray-800 my-3 w-full" />
-
-							<div class="flex flex-col w-full">
-								<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Email')}</div>
-
-								<div class="flex-1">
-									<input
-										class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
-										type="email"
-										bind:value={_user.email}
-										placeholder={$i18n.t('Enter Your Email')}
-										autocomplete="off"
-										required
-									/>
-								</div>
-							</div>
-
-							<div class="flex flex-col w-full mt-2">
-								<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Password')}</div>
-
-								<div class="flex-1">
-									<input
-										class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
-										type="password"
-										bind:value={_user.password}
-										placeholder={$i18n.t('Enter Your Password')}
-										autocomplete="off"
-									/>
-								</div>
-							</div>
-						{:else if tab === 'import'}
-							<div>
-								<div class="mb-3 w-full">
-									<input
-										id="upload-user-csv-input"
-										hidden
-										bind:files={inputFiles}
-										type="file"
-										accept=".csv"
-									/>
-
-									<button
-										class="w-full text-sm font-medium py-3 bg-transparent hover:bg-gray-100 border border-dashed dark:border-gray-800 dark:hover:bg-gray-850 text-center rounded-xl"
-										type="button"
-										on:click={() => {
-											document.getElementById('upload-user-csv-input')?.click();
-										}}
-									>
-										{#if inputFiles}
-											{inputFiles.length > 0 ? `${inputFiles.length}` : ''} document(s) selected.
-										{:else}
-											{$i18n.t('Click here to select a csv file.')}
+						{#if !loading}
+							{#if valvesSpec}
+								{#each Object.keys(valvesSpec.properties) as property, idx}
+									<div class=" py-0.5 w-full justify-between">
+										<div class="flex w-full justify-between">
+											<div class=" self-center text-xs font-medium">
+												{valvesSpec.properties[property].title}
+
+												{#if (valvesSpec?.required ?? []).includes(property)}
+													<span class=" text-gray-500">*required</span>
+												{/if}
+											</div>
+
+											<button
+												class="p-1 px-3 text-xs flex rounded transition"
+												type="button"
+												on:click={() => {
+													valves[property] = (valves[property] ?? null) === null ? '' : null;
+												}}
+											>
+												{#if (valves[property] ?? null) === null}
+													<span class="ml-2 self-center">
+														{#if (valvesSpec?.required ?? []).includes(property)}
+															{$i18n.t('None')}
+														{:else}
+															{$i18n.t('Default')}
+														{/if}
+													</span>
+												{:else}
+													<span class="ml-2 self-center"> {$i18n.t('Custom')} </span>
+												{/if}
+											</button>
+										</div>
+
+										{#if (valves[property] ?? null) !== null}
+											<div class="flex mt-0.5 mb-1.5 space-x-2">
+												<div class=" flex-1">
+													<input
+														class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+														type="text"
+														placeholder={valvesSpec.properties[property].title}
+														bind:value={valves[property]}
+														autocomplete="off"
+														required={(valvesSpec?.required ?? []).includes(property)}
+													/>
+												</div>
+											</div>
 										{/if}
-									</button>
-								</div>
 
-								<div class=" text-xs text-gray-500">
-									ⓘ {$i18n.t(
-										'Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.'
-									)}
-									<a
-										class="underline dark:text-gray-200"
-										href="{WEBUI_BASE_URL}/static/user-import.csv"
-									>
-										{$i18n.t('Click here to download user import template file.')}
-									</a>
-								</div>
-							</div>
+										{#if (valvesSpec.properties[property]?.description ?? null) !== null}
+											<div class="text-xs text-gray-500">
+												{valvesSpec.properties[property].description}
+											</div>
+										{/if}
+									</div>
+								{/each}
+							{:else}
+								<div>No valves</div>
+							{/if}
+						{:else}
+							<Spinner className="size-5" />
 						{/if}
 					</div>
 
 					<div class="flex justify-end pt-3 text-sm font-medium">
 						<button
-							class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {loading
+							class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {saving
 								? ' cursor-not-allowed'
 								: ''}"
 							type="submit"
-							disabled={loading}
+							disabled={saving}
 						>
-							{$i18n.t('Submit')}
+							{$i18n.t('Save')}
 
-							{#if loading}
+							{#if saving}
 								<div class="ml-2 self-center">
 									<svg
 										class=" w-4 h-4"

+ 1 - 1
src/routes/(app)/workspace/+layout.svelte

@@ -9,7 +9,7 @@
 	const i18n = getContext('i18n');
 
 	onMount(async () => {
-		functions.set(await getFunctions(localStorage.token));
+		// functions.set(await getFunctions(localStorage.token));
 	});
 </script>