浏览代码

tool calling refactor

Michael Poluektov 8 月之前
父节点
当前提交
fdc89cbcee
共有 1 个文件被更改,包括 59 次插入139 次删除
  1. 59 139
      backend/main.py

+ 59 - 139
backend/main.py

@@ -298,30 +298,6 @@ async def get_content_from_response(response) -> Optional[str]:
     return content
 
 
-async def call_tool_from_completion(
-    result: dict, extra_params: dict, toolkit_module
-) -> Optional[str]:
-    if "name" not in result:
-        return None
-
-    tool = getattr(toolkit_module, result["name"])
-    try:
-        # Get the signature of the function
-        sig = inspect.signature(tool)
-        params = result["parameters"]
-        for key, value in extra_params.items():
-            if key in sig.parameters:
-                params[key] = value
-
-        if inspect.iscoroutinefunction(tool):
-            return await tool(**params)
-        else:
-            return tool(**params)
-    except Exception as e:
-        print(f"Error: {e}")
-        return None
-
-
 def get_tool_call_payload(messages, task_model_id, content):
     user_message = get_last_user_message(messages)
     history = "\n".join(
@@ -342,90 +318,6 @@ def get_tool_call_payload(messages, task_model_id, content):
     }
 
 
-async def get_tool_call_response(
-    messages, files, tool_id, template, task_model_id, user, extra_params
-) -> tuple[Optional[str], Optional[dict], bool]:
-    """
-    return: tuple of (function_result, citation, file_handler) where
-    - function_result: Optional[str] is the result of the tool call if successful
-    - citation: Optional[dict] is the citation object if the tool has citation
-    - file_handler: bool, True if tool handles files
-    """
-    tool = Tools.get_tool_by_id(tool_id)
-    if tool is None:
-        return None, None, False
-
-    tools_specs = json.dumps(tool.specs, indent=2)
-    log.debug(f"{tool.specs=}")
-    content = tool_calling_generation_template(template, tools_specs)
-    payload = get_tool_call_payload(messages, task_model_id, content)
-
-    try:
-        payload = filter_pipeline(payload, user)
-    except Exception as e:
-        raise e
-
-    if tool_id in webui_app.state.TOOLS:
-        toolkit_module = webui_app.state.TOOLS[tool_id]
-    else:
-        toolkit_module, _ = load_toolkit_module_by_id(tool_id)
-        webui_app.state.TOOLS[tool_id] = toolkit_module
-
-    custom_params = {
-        **extra_params,
-        "__model__": app.state.MODELS[task_model_id],
-        "__id__": tool_id,
-        "__messages__": messages,
-        "__files__": files,
-    }
-    try:
-        if hasattr(toolkit_module, "UserValves"):
-            custom_params["__user__"]["valves"] = toolkit_module.UserValves(
-                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-            )
-
-    except Exception as e:
-        print(e)
-
-    file_handler = hasattr(toolkit_module, "file_handler")
-
-    if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
-        valves = Tools.get_tool_valves_by_id(tool_id)
-        toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
-
-    try:
-        response = await generate_chat_completions(form_data=payload, user=user)
-        content = await get_content_from_response(response)
-
-        if content is None:
-            return None, None, False
-
-        # Parse the function response
-        log.debug(f"content: {content}")
-        result = json.loads(content)
-
-        function_result = await call_tool_from_completion(
-            result, custom_params, toolkit_module
-        )
-
-        if hasattr(toolkit_module, "citation") and toolkit_module.citation:
-            citation = {
-                "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
-                "document": [function_result],
-                "metadata": [{"source": result["name"]}],
-            }
-        else:
-            citation = None
-
-        # Add the function result to the system prompt
-        if function_result is not None:
-            return function_result, citation, file_handler
-    except Exception as e:
-        print(f"Error: {e}")
-
-    return None, None, False
-
-
 async def chat_completion_inlets_handler(body, model, extra_params):
     skip_files = None
 
@@ -511,6 +403,7 @@ def get_tool_with_custom_params(
     return new_tool
 
 
+# Mutation on extra_params
 def get_configured_tools(
     tool_ids: list[str], extra_params: dict, user: UserModel
 ) -> dict[str, dict]:
@@ -525,8 +418,7 @@ def get_configured_tools(
             module, _ = load_toolkit_module_by_id(tool_id)
             webui_app.state.TOOLS[tool_id] = module
 
-        more_params = {"__id__": tool_id}
-        custom_params = more_params | extra_params
+        extra_params["__id__"] = tool_id
         has_citation = hasattr(module, "citation") and module.citation
         handles_files = hasattr(module, "file_handler") and module.file_handler
         if hasattr(module, "valves") and hasattr(module, "Valves"):
@@ -534,27 +426,27 @@ def get_configured_tools(
             module.valves = module.Valves(**valves)
 
         if hasattr(module, "UserValves"):
-            custom_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
                 **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
             )
 
         for spec in toolkit.specs:
             name = spec["name"]
             callable = getattr(module, name)
+
             # convert to function that takes only model params and inserts custom params
-            custom_callable = get_tool_with_custom_params(callable, custom_params)
+            custom_callable = get_tool_with_custom_params(callable, extra_params)
 
             tool_dict = {
                 "spec": spec,
                 "citation": has_citation,
                 "file_handler": handles_files,
-                "toolkit_module": module,
+                "toolkit_id": tool_id,
                 "callable": custom_callable,
             }
             if name in tools:
                 log.warning(f"Tool {name} already exists in another toolkit!")
-                mod_name = tools[name]["toolkit_module"].__name__
-                log.warning(f"Collision between {toolkit} and {mod_name}.")
+                log.warning(f"Collision between {toolkit} and {tool_id}.")
                 log.warning(f"Discarding {toolkit}.{name}")
             else:
                 tools[name] = tool_dict
@@ -571,40 +463,68 @@ async def chat_completion_tools_handler(
     task_model_id = get_task_model_id(body["model"])
 
     # If tool_ids field is present, call the functions
-    if "tool_ids" not in body:
+    tool_ids = body.pop("tool_ids", None)
+    if not tool_ids:
         return body, {}
 
-    log.debug(f"tool_ids: {body['tool_ids']}")
-    log.info(f"{get_configured_tools(body['tool_ids'], extra_params, user)=}")
-    kwargs = {
-        "messages": body["messages"],
-        "files": body.get("files", []),
-        "template": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-        "task_model_id": task_model_id,
-        "user": user,
-        "extra_params": extra_params,
+    log.debug(f"{tool_ids=}")
+    custom_params = {
+        **extra_params,
+        "__model__": app.state.MODELS[task_model_id],
+        "__messages__": body["messages"],
+        "__files__": body.get("files", []),
     }
+    configured_tools = get_configured_tools(tool_ids, custom_params, user)
 
-    for tool_id in body["tool_ids"]:
-        log.debug(f"{tool_id=}")
-        try:
-            response, citation, file_handler = await get_tool_call_response(
-                tool_id=tool_id, **kwargs
-            )
+    log.info(f"{configured_tools=}")
 
-            if isinstance(response, str):
-                contexts.append(response)
+    specs = [tool["spec"] for tool in configured_tools.values()]
+    tools_specs = json.dumps(specs)
+    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+    content = tool_calling_generation_template(template, tools_specs)
+    payload = get_tool_call_payload(body["messages"], task_model_id, content)
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        raise e
 
-            if citation:
-                citations.append(citation)
+    try:
+        response = await generate_chat_completions(form_data=payload, user=user)
+        log.debug(f"{response=}")
+        content = await get_content_from_response(response)
+        log.debug(f"{content=}")
+        if content is None:
+            return body, {}
 
-            if file_handler:
-                skip_files = True
+        result = json.loads(content)
+        tool_name = result.get("name", None)
+        if tool_name not in configured_tools:
+            return body, {}
 
+        tool_params = result.get("parameters", {})
+        toolkit_id = configured_tools[tool_name]["toolkit_id"]
+        try:
+            tool_output = await configured_tools[tool_name]["callable"](**tool_params)
         except Exception as e:
-            log.exception(f"Error: {e}")
+            tool_output = str(e)
+        if configured_tools[tool_name]["citation"]:
+            citations.append(
+                {
+                    "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
+                    "document": [tool_output],
+                    "metadata": [{"source": tool_name}],
+                }
+            )
+        if configured_tools[tool_name]["file_handler"]:
+            skip_files = True
+
+        if isinstance(tool_output, str):
+            contexts.append(tool_output)
+
+    except Exception as e:
+        print(f"Error: {e}")
+        content = None
 
-    del body["tool_ids"]
     log.debug(f"tool_contexts: {contexts}")
 
     if skip_files and "files" in body: