|
@@ -926,6 +926,7 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
|
|
|
|
|
|
|
|
|
|
async def get_all_models():
|
|
async def get_all_models():
|
|
|
|
+ # TODO: Optimize this function
|
|
pipe_models = []
|
|
pipe_models = []
|
|
openai_models = []
|
|
openai_models = []
|
|
ollama_models = []
|
|
ollama_models = []
|
|
@@ -952,6 +953,14 @@ async def get_all_models():
|
|
|
|
|
|
models = pipe_models + openai_models + ollama_models
|
|
models = pipe_models + openai_models + ollama_models
|
|
|
|
|
|
|
|
+ global_action_ids = [
|
|
|
|
+ function.id for function in Functions.get_global_action_functions()
|
|
|
|
+ ]
|
|
|
|
+ enabled_action_ids = [
|
|
|
|
+ function.id
|
|
|
|
+ for function in Functions.get_functions_by_type("action", active_only=True)
|
|
|
|
+ ]
|
|
|
|
+
|
|
custom_models = Models.get_all_models()
|
|
custom_models = Models.get_all_models()
|
|
for custom_model in custom_models:
|
|
for custom_model in custom_models:
|
|
if custom_model.base_model_id == None:
|
|
if custom_model.base_model_id == None:
|
|
@@ -962,9 +971,32 @@ async def get_all_models():
|
|
):
|
|
):
|
|
model["name"] = custom_model.name
|
|
model["name"] = custom_model.name
|
|
model["info"] = custom_model.model_dump()
|
|
model["info"] = custom_model.model_dump()
|
|
|
|
+
|
|
|
|
+ action_ids = [] + global_action_ids
|
|
|
|
+ if "info" in model and "meta" in model["info"]:
|
|
|
|
+ action_ids.extend(model["info"]["meta"].get("actionIds", []))
|
|
|
|
+ action_ids = list(set(action_ids))
|
|
|
|
+ action_ids = [
|
|
|
|
+ action_id
|
|
|
|
+ for action_id in action_ids
|
|
|
|
+ if action_id in enabled_action_ids
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ model["actions"] = [
|
|
|
|
+ {
|
|
|
|
+ "id": action_id,
|
|
|
|
+ "name": Functions.get_function_by_id(action_id).name,
|
|
|
|
+ "description": Functions.get_function_by_id(
|
|
|
|
+ action_id
|
|
|
|
+ ).meta.description,
|
|
|
|
+ }
|
|
|
|
+ for action_id in action_ids
|
|
|
|
+ ]
|
|
|
|
+
|
|
else:
|
|
else:
|
|
owned_by = "openai"
|
|
owned_by = "openai"
|
|
pipe = None
|
|
pipe = None
|
|
|
|
+ actions = []
|
|
|
|
|
|
for model in models:
|
|
for model in models:
|
|
if (
|
|
if (
|
|
@@ -974,6 +1006,27 @@ async def get_all_models():
|
|
owned_by = model["owned_by"]
|
|
owned_by = model["owned_by"]
|
|
if "pipe" in model:
|
|
if "pipe" in model:
|
|
pipe = model["pipe"]
|
|
pipe = model["pipe"]
|
|
|
|
+
|
|
|
|
+ action_ids = [] + global_action_ids
|
|
|
|
+ if "info" in model and "meta" in model["info"]:
|
|
|
|
+ action_ids.extend(model["info"]["meta"].get("actionIds", []))
|
|
|
|
+ action_ids = list(set(action_ids))
|
|
|
|
+ action_ids = [
|
|
|
|
+ action_id
|
|
|
|
+ for action_id in action_ids
|
|
|
|
+ if action_id in enabled_action_ids
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ actions = [
|
|
|
|
+ {
|
|
|
|
+ "id": action_id,
|
|
|
|
+ "name": Functions.get_function_by_id(action_id).name,
|
|
|
|
+ "description": Functions.get_function_by_id(
|
|
|
|
+ action_id
|
|
|
|
+ ).meta.description,
|
|
|
|
+ }
|
|
|
|
+ for action_id in action_ids
|
|
|
|
+ ]
|
|
break
|
|
break
|
|
|
|
|
|
models.append(
|
|
models.append(
|
|
@@ -986,6 +1039,7 @@ async def get_all_models():
|
|
"info": custom_model.model_dump(),
|
|
"info": custom_model.model_dump(),
|
|
"preset": True,
|
|
"preset": True,
|
|
**({"pipe": pipe} if pipe is not None else {}),
|
|
**({"pipe": pipe} if pipe is not None else {}),
|
|
|
|
+ "actions": actions,
|
|
}
|
|
}
|
|
)
|
|
)
|
|
|
|
|
|
@@ -1221,6 +1275,107 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|
return data
|
|
return data
|
|
|
|
|
|
|
|
|
|
|
|
+@app.post("/api/chat/actions/{action_id}")
|
|
|
|
+async def chat_completed(
|
|
|
|
+ action_id: str, form_data: dict, user=Depends(get_verified_user)
|
|
|
|
+):
|
|
|
|
+ action = Functions.get_function_by_id(action_id)
|
|
|
|
+ if not action:
|
|
|
|
+ raise HTTPException(
|
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
|
+ detail="Action not found",
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ data = form_data
|
|
|
|
+ model_id = data["model"]
|
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
|
+ raise HTTPException(
|
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
|
+ detail="Model not found",
|
|
|
|
+ )
|
|
|
|
+ model = app.state.MODELS[model_id]
|
|
|
|
+
|
|
|
|
+ __event_emitter__ = await get_event_emitter(
|
|
|
|
+ {
|
|
|
|
+ "chat_id": data["chat_id"],
|
|
|
|
+ "message_id": data["id"],
|
|
|
|
+ "session_id": data["session_id"],
|
|
|
|
+ }
|
|
|
|
+ )
|
|
|
|
+ __event_call__ = await get_event_call(
|
|
|
|
+ {
|
|
|
|
+ "chat_id": data["chat_id"],
|
|
|
|
+ "message_id": data["id"],
|
|
|
|
+ "session_id": data["session_id"],
|
|
|
|
+ }
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ if action_id in webui_app.state.FUNCTIONS:
|
|
|
|
+ function_module = webui_app.state.FUNCTIONS[action_id]
|
|
|
|
+ else:
|
|
|
|
+ function_module, _, _ = load_function_module_by_id(action_id)
|
|
|
|
+ webui_app.state.FUNCTIONS[action_id] = function_module
|
|
|
|
+
|
|
|
|
+ if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
|
|
|
+ valves = Functions.get_function_valves_by_id(action_id)
|
|
|
|
+ function_module.valves = function_module.Valves(**(valves if valves else {}))
|
|
|
|
+
|
|
|
|
+ if hasattr(function_module, "action"):
|
|
|
|
+ try:
|
|
|
|
+ action = function_module.action
|
|
|
|
+
|
|
|
|
+ # Get the signature of the function
|
|
|
|
+ sig = inspect.signature(action)
|
|
|
|
+ params = {"body": data}
|
|
|
|
+
|
|
|
|
+ # Extra parameters to be passed to the function
|
|
|
|
+ extra_params = {
|
|
|
|
+ "__model__": model,
|
|
|
|
+ "__id__": action_id,
|
|
|
|
+ "__event_emitter__": __event_emitter__,
|
|
|
|
+ "__event_call__": __event_call__,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ # Add extra params in contained in function signature
|
|
|
|
+ for key, value in extra_params.items():
|
|
|
|
+ if key in sig.parameters:
|
|
|
|
+ params[key] = value
|
|
|
|
+
|
|
|
|
+ if "__user__" in sig.parameters:
|
|
|
|
+ __user__ = {
|
|
|
|
+ "id": user.id,
|
|
|
|
+ "email": user.email,
|
|
|
|
+ "name": user.name,
|
|
|
|
+ "role": user.role,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ if hasattr(function_module, "UserValves"):
|
|
|
|
+ __user__["valves"] = function_module.UserValves(
|
|
|
|
+ **Functions.get_user_valves_by_id_and_user_id(
|
|
|
|
+ action_id, user.id
|
|
|
|
+ )
|
|
|
|
+ )
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(e)
|
|
|
|
+
|
|
|
|
+ params = {**params, "__user__": __user__}
|
|
|
|
+
|
|
|
|
+ if inspect.iscoroutinefunction(action):
|
|
|
|
+ data = await action(**params)
|
|
|
|
+ else:
|
|
|
|
+ data = action(**params)
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"Error: {e}")
|
|
|
|
+ return JSONResponse(
|
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
|
+ content={"detail": str(e)},
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ return data
|
|
|
|
+
|
|
|
|
+
|
|
##################################
|
|
##################################
|
|
#
|
|
#
|
|
# Task Endpoints
|
|
# Task Endpoints
|