|
@@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|
|
|
|
|
|
|
|
async def chat_completion_tools_handler(
|
|
|
- request: Request, body: dict, user: UserModel, models, tools
|
|
|
+ request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
|
|
|
) -> tuple[dict, dict]:
|
|
|
async def get_content_from_response(response) -> Optional[str]:
|
|
|
content = None
|
|
@@ -135,6 +135,9 @@ async def chat_completion_tools_handler(
|
|
|
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
|
|
}
|
|
|
|
|
|
+ event_caller = extra_params["__event_call__"]
|
|
|
+ metadata = extra_params["__metadata__"]
|
|
|
+
|
|
|
task_model_id = get_task_model_id(
|
|
|
body["model"],
|
|
|
request.app.state.config.TASK_MODEL,
|
|
@@ -189,17 +192,33 @@ async def chat_completion_tools_handler(
|
|
|
tool_function_params = tool_call.get("parameters", {})
|
|
|
|
|
|
try:
|
|
|
- spec = tools[tool_function_name].get("spec", {})
|
|
|
+ tool = tools[tool_function_name]
|
|
|
+
|
|
|
+ spec = tool.get("spec", {})
|
|
|
allowed_params = (
|
|
|
spec.get("parameters", {}).get("properties", {}).keys()
|
|
|
)
|
|
|
- tool_function = tools[tool_function_name]["callable"]
|
|
|
+ tool_function = tool["callable"]
|
|
|
tool_function_params = {
|
|
|
k: v
|
|
|
for k, v in tool_function_params.items()
|
|
|
if k in allowed_params
|
|
|
}
|
|
|
- tool_output = await tool_function(**tool_function_params)
|
|
|
+
|
|
|
+ if tool.get("direct", False):
|
|
|
+ tool_output = await tool_function(**tool_function_params)
|
|
|
+ else:
|
|
|
+ tool_output = await event_caller(
|
|
|
+ {
|
|
|
+ "type": "execute:tool",
|
|
|
+ "data": {
|
|
|
+ "id": str(uuid4()),
|
|
|
+ "tool": tool,
|
|
|
+ "params": tool_function_params,
|
|
|
+ "session_id": metadata.get("session_id", None),
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|
|
|
|
|
|
except Exception as e:
|
|
|
tool_output = str(e)
|
|
@@ -764,12 +783,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|
|
}
|
|
|
form_data["metadata"] = metadata
|
|
|
|
|
|
+ # Server side tools
|
|
|
tool_ids = metadata.get("tool_ids", None)
|
|
|
+ # Client side tools
|
|
|
+ tool_specs = form_data.get("tool_specs", None)
|
|
|
+
|
|
|
log.debug(f"{tool_ids=}")
|
|
|
+ log.debug(f"{tool_specs=}")
|
|
|
+
|
|
|
+ tools_dict = {}
|
|
|
|
|
|
if tool_ids:
|
|
|
- # If tool_ids field is present, then get the tools
|
|
|
- tools = get_tools(
|
|
|
+ tools_dict = get_tools(
|
|
|
request,
|
|
|
tool_ids,
|
|
|
user,
|
|
@@ -780,20 +805,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|
|
"__files__": metadata.get("files", []),
|
|
|
},
|
|
|
)
|
|
|
- log.info(f"{tools=}")
|
|
|
+ log.info(f"{tools_dict=}")
|
|
|
+
|
|
|
+ if tool_specs:
|
|
|
+ for tool in tool_specs:
|
|
|
+ callable = tool.pop("callable", None)
|
|
|
+ tools_dict[tool["name"]] = {
|
|
|
+ "direct": True,
|
|
|
+ "callable": callable,
|
|
|
+ "spec": tool,
|
|
|
+ }
|
|
|
|
|
|
+ if tools_dict:
|
|
|
if metadata.get("function_calling") == "native":
|
|
|
# If the function calling is native, then call the tools function calling handler
|
|
|
- metadata["tools"] = tools
|
|
|
+ metadata["tools"] = tools_dict
|
|
|
form_data["tools"] = [
|
|
|
{"type": "function", "function": tool.get("spec", {})}
|
|
|
- for tool in tools.values()
|
|
|
+ for tool in tools_dict.values()
|
|
|
]
|
|
|
else:
|
|
|
# If the function calling is not native, then call the tools function calling handler
|
|
|
try:
|
|
|
form_data, flags = await chat_completion_tools_handler(
|
|
|
- request, form_data, user, models, tools
|
|
|
+ request, form_data, extra_params, user, models, tools_dict
|
|
|
)
|
|
|
sources.extend(flags.get("sources", []))
|
|
|
|
|
@@ -1774,9 +1809,25 @@ async def process_chat_response(
|
|
|
for k, v in tool_function_params.items()
|
|
|
if k in allowed_params
|
|
|
}
|
|
|
- tool_result = await tool_function(
|
|
|
- **tool_function_params
|
|
|
- )
|
|
|
+
|
|
|
+ if tool.get("direct", False):
|
|
|
+ tool_result = await tool_function(
|
|
|
+ **tool_function_params
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ tool_result = await event_caller(
|
|
|
+ {
|
|
|
+ "type": "execute:tool",
|
|
|
+ "data": {
|
|
|
+ "id": str(uuid4()),
|
|
|
+ "tool": tool,
|
|
|
+ "params": tool_function_params,
|
|
|
+ "session_id": metadata.get(
|
|
|
+ "session_id", None
|
|
|
+ ),
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|
|
|
except Exception as e:
|
|
|
tool_result = str(e)
|
|
|
|