Browse Source

feat: action function

Timothy J. Baek 9 months ago
parent
commit
eb10001eb7

+ 9 - 0
backend/apps/webui/models/functions.py

@@ -167,6 +167,15 @@ class FunctionsTable:
                 .all()
                 .all()
             ]
             ]
 
 
+    def get_global_action_functions(self) -> List[FunctionModel]:
+        with get_db() as db:
+            return [
+                FunctionModel.model_validate(function)
+                for function in db.query(Function)
+                .filter_by(type="action", is_active=True, is_global=True)
+                .all()
+            ]
+
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
         with get_db() as db:
         with get_db() as db:
 
 

+ 2 - 0
backend/apps/webui/utils.py

@@ -79,6 +79,8 @@ def load_function_module_by_id(function_id):
             return module.Pipe(), "pipe", frontmatter
             return module.Pipe(), "pipe", frontmatter
         elif hasattr(module, "Filter"):
         elif hasattr(module, "Filter"):
             return module.Filter(), "filter", frontmatter
             return module.Filter(), "filter", frontmatter
+        elif hasattr(module, "Action"):
+            return module.Action(), "action", frontmatter
         else:
         else:
             raise Exception("No Function class found")
             raise Exception("No Function class found")
     except Exception as e:
     except Exception as e:

+ 155 - 0
backend/main.py

@@ -926,6 +926,7 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 
 
 
 
 async def get_all_models():
 async def get_all_models():
+    # TODO: Optimize this function
     pipe_models = []
     pipe_models = []
     openai_models = []
     openai_models = []
     ollama_models = []
     ollama_models = []
@@ -952,6 +953,14 @@ async def get_all_models():
 
 
     models = pipe_models + openai_models + ollama_models
     models = pipe_models + openai_models + ollama_models
 
 
+    global_action_ids = [
+        function.id for function in Functions.get_global_action_functions()
+    ]
+    enabled_action_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("action", active_only=True)
+    ]
+
     custom_models = Models.get_all_models()
     custom_models = Models.get_all_models()
     for custom_model in custom_models:
     for custom_model in custom_models:
         if custom_model.base_model_id == None:
         if custom_model.base_model_id == None:
@@ -962,9 +971,32 @@ async def get_all_models():
                 ):
                 ):
                     model["name"] = custom_model.name
                     model["name"] = custom_model.name
                     model["info"] = custom_model.model_dump()
                     model["info"] = custom_model.model_dump()
+
+                    action_ids = [] + global_action_ids
+                    if "info" in model and "meta" in model["info"]:
+                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
+                        action_ids = list(set(action_ids))
+                    action_ids = [
+                        action_id
+                        for action_id in action_ids
+                        if action_id in enabled_action_ids
+                    ]
+
+                    model["actions"] = [
+                        {
+                            "id": action_id,
+                            "name": Functions.get_function_by_id(action_id).name,
+                            "description": Functions.get_function_by_id(
+                                action_id
+                            ).meta.description,
+                        }
+                        for action_id in action_ids
+                    ]
+
         else:
         else:
             owned_by = "openai"
             owned_by = "openai"
             pipe = None
             pipe = None
+            actions = []
 
 
             for model in models:
             for model in models:
                 if (
                 if (
@@ -974,6 +1006,27 @@ async def get_all_models():
                     owned_by = model["owned_by"]
                     owned_by = model["owned_by"]
                     if "pipe" in model:
                     if "pipe" in model:
                         pipe = model["pipe"]
                         pipe = model["pipe"]
+
+                    action_ids = [] + global_action_ids
+                    if "info" in model and "meta" in model["info"]:
+                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
+                        action_ids = list(set(action_ids))
+                    action_ids = [
+                        action_id
+                        for action_id in action_ids
+                        if action_id in enabled_action_ids
+                    ]
+
+                    actions = [
+                        {
+                            "id": action_id,
+                            "name": Functions.get_function_by_id(action_id).name,
+                            "description": Functions.get_function_by_id(
+                                action_id
+                            ).meta.description,
+                        }
+                        for action_id in action_ids
+                    ]
                     break
                     break
 
 
             models.append(
             models.append(
@@ -986,6 +1039,7 @@ async def get_all_models():
                     "info": custom_model.model_dump(),
                     "info": custom_model.model_dump(),
                     "preset": True,
                     "preset": True,
                     **({"pipe": pipe} if pipe is not None else {}),
                     **({"pipe": pipe} if pipe is not None else {}),
+                    "actions": actions,
                 }
                 }
             )
             )
 
 
@@ -1221,6 +1275,107 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     return data
     return data
 
 
 
 
+@app.post("/api/chat/actions/{action_id}")
+async def chat_completed(
+    action_id: str, form_data: dict, user=Depends(get_verified_user)
+):
+    action = Functions.get_function_by_id(action_id)
+    if not action:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Action not found",
+        )
+    
+    data = form_data
+    model_id = data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+    model = app.state.MODELS[model_id]
+
+    __event_emitter__ = await get_event_emitter(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+    __event_call__ = await get_event_call(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+
+    if action_id in webui_app.state.FUNCTIONS:
+        function_module = webui_app.state.FUNCTIONS[action_id]
+    else:
+        function_module, _, _ = load_function_module_by_id(action_id)
+        webui_app.state.FUNCTIONS[action_id] = function_module
+
+    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        valves = Functions.get_function_valves_by_id(action_id)
+        function_module.valves = function_module.Valves(**(valves if valves else {}))
+
+    if hasattr(function_module, "action"):
+        try:
+            action = function_module.action
+
+            # Get the signature of the function
+            sig = inspect.signature(action)
+            params = {"body": data}
+
+            # Extra parameters to be passed to the function
+            extra_params = {
+                "__model__": model,
+                "__id__": action_id,
+                "__event_emitter__": __event_emitter__,
+                "__event_call__": __event_call__,
+            }
+
+            # Add extra params in contained in function signature
+            for key, value in extra_params.items():
+                if key in sig.parameters:
+                    params[key] = value
+
+            if "__user__" in sig.parameters:
+                __user__ = {
+                    "id": user.id,
+                    "email": user.email,
+                    "name": user.name,
+                    "role": user.role,
+                }
+
+                try:
+                    if hasattr(function_module, "UserValves"):
+                        __user__["valves"] = function_module.UserValves(
+                            **Functions.get_user_valves_by_id_and_user_id(
+                                action_id, user.id
+                            )
+                        )
+                except Exception as e:
+                    print(e)
+
+                params = {**params, "__user__": __user__}
+
+            if inspect.iscoroutinefunction(action):
+                data = await action(**params)
+            else:
+                data = action(**params)
+
+        except Exception as e:
+            print(f"Error: {e}")
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+
+    return data
+
+
 ##################################
 ##################################
 #
 #
 # Task Endpoints
 # Task Endpoints

+ 17 - 0
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -37,6 +37,7 @@
 	import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
 	import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
 	import Spinner from '$lib/components/common/Spinner.svelte';
 	import Spinner from '$lib/components/common/Spinner.svelte';
 	import WebSearchResults from './ResponseMessage/WebSearchResults.svelte';
 	import WebSearchResults from './ResponseMessage/WebSearchResults.svelte';
+	import Sparkles from '$lib/components/icons/Sparkles.svelte';
 
 
 	export let message;
 	export let message;
 	export let siblings;
 	export let siblings;
@@ -1020,6 +1021,22 @@
 														</svg>
 														</svg>
 													</button>
 													</button>
 												</Tooltip>
 												</Tooltip>
+
+												{#each model?.actions ?? [] as action}
+													<Tooltip content={action.name} placement="bottom">
+														<button
+															type="button"
+															class="{isLastMessage
+																? 'visible'
+																: 'invisible group-hover:visible'} p-1.5 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button"
+															on:click={() => {
+																console.log('action');
+															}}
+														>
+															<Sparkles strokeWidth="2.1" className="size-4" />
+														</button>
+													</Tooltip>
+												{/each}
 											{/if}
 											{/if}
 										{/if}
 										{/if}
 									{/if}
 									{/if}

+ 19 - 0
src/lib/components/icons/Sparkles.svelte

@@ -0,0 +1,19 @@
+<script lang="ts">
+	export let className = 'w-4 h-4';
+	export let strokeWidth = '1.5';
+</script>
+
+<svg
+	xmlns="http://www.w3.org/2000/svg"
+	fill="none"
+	viewBox="0 0 24 24"
+	stroke-width={strokeWidth}
+	stroke="currentColor"
+	class={className}
+>
+	<path
+		stroke-linecap="round"
+		stroke-linejoin="round"
+		d="M9.813 15.904 9 18.75l-.813-2.846a4.5 4.5 0 0 0-3.09-3.09L2.25 12l2.846-.813a4.5 4.5 0 0 0 3.09-3.09L9 5.25l.813 2.846a4.5 4.5 0 0 0 3.09 3.09L15.75 12l-2.846.813a4.5 4.5 0 0 0-3.09 3.09ZM18.259 8.715 18 9.75l-.259-1.035a3.375 3.375 0 0 0-2.455-2.456L14.25 6l1.036-.259a3.375 3.375 0 0 0 2.455-2.456L18 2.25l.259 1.035a3.375 3.375 0 0 0 2.456 2.456L21.75 6l-1.035.259a3.375 3.375 0 0 0-2.456 2.456ZM16.894 20.567 16.5 21.75l-.394-1.183a2.25 2.25 0 0 0-1.423-1.423L13.5 18.75l1.183-.394a2.25 2.25 0 0 0 1.423-1.423l.394-1.183.394 1.183a2.25 2.25 0 0 0 1.423 1.423l1.183.394-1.183.394a2.25 2.25 0 0 0-1.423 1.423Z"
+	/>
+</svg>

+ 8 - 3
src/lib/components/workspace/Functions.svelte

@@ -122,12 +122,17 @@
 
 
 		if (res) {
 		if (res) {
 			if (func.is_global) {
 			if (func.is_global) {
-				toast.success($i18n.t('Filter is now globally enabled'));
+				func.type === 'filter'
+					? toast.success($i18n.t('Filter is now globally enabled'))
+					: toast.success($i18n.t('Function is now globally enabled'));
 			} else {
 			} else {
-				toast.success($i18n.t('Filter is now globally disabled'));
+				func.type === 'filter'
+					? toast.success($i18n.t('Filter is now globally disabled'))
+					: toast.success($i18n.t('Function is now globally disabled'));
 			}
 			}
 
 
 			functions.set(await getFunctions(localStorage.token));
 			functions.set(await getFunctions(localStorage.token));
+			models.set(await getModels(localStorage.token));
 		}
 		}
 	};
 	};
 </script>
 </script>
@@ -294,7 +299,7 @@
 						showDeleteConfirm = true;
 						showDeleteConfirm = true;
 					}}
 					}}
 					toggleGlobalHandler={() => {
 					toggleGlobalHandler={() => {
-						if (func.type === 'filter') {
+						if (['filter', 'action'].includes(func.type)) {
 							toggleGlobalHandler(func);
 							toggleGlobalHandler(func);
 						}
 						}
 					}}
 					}}

+ 1 - 1
src/lib/components/workspace/Functions/FunctionMenu.svelte

@@ -48,7 +48,7 @@
 			align="start"
 			align="start"
 			transition={flyAndScale}
 			transition={flyAndScale}
 		>
 		>
-			{#if func.type === 'filter'}
+			{#if ['filter', 'action'].includes(func.type)}
 				<div
 				<div
 					class="flex gap-2 justify-between items-center px-3 py-2 text-sm font-medium cursor-pointerrounded-md"
 					class="flex gap-2 justify-between items-center px-3 py-2 text-sm font-medium cursor-pointerrounded-md"
 				>
 				>