Browse Source

add get_configured_tools

Michael Poluektov 8 months ago
parent
commit
6df6170c44
1 changed files with 81 additions and 18 deletions
  1. 81 18
      backend/main.py

+ 81 - 18
backend/main.py

@@ -51,13 +51,13 @@ from apps.webui.internal.db import Session
 
 
 from pydantic import BaseModel
-from typing import Optional
+from typing import Optional, Callable, Awaitable
 
 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
 from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users, User
+from apps.webui.models.users import Users, UserModel
 
 from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
 
@@ -356,6 +356,7 @@ async def get_tool_call_response(
         return None, None, False
 
     tools_specs = json.dumps(tool.specs, indent=2)
+    log.debug(f"{tool.specs=}")
     content = tool_calling_generation_template(template, tools_specs)
     payload = get_tool_call_payload(messages, task_model_id, content)
 
@@ -492,14 +493,81 @@ async def chat_completion_inlets_handler(body, model, extra_params):
     return body, {}
 
 
+def get_tool_with_custom_params(
+    tool: Callable, custom_params: dict
+) -> Callable[..., Awaitable]:
+    sig = inspect.signature(tool)
+    extra_params = {
+        key: value for key, value in custom_params.items() if key in sig.parameters
+    }
+    is_coroutine = inspect.iscoroutinefunction(tool)
+
+    async def new_tool(**kwargs):
+        extra_kwargs = kwargs | extra_params
+        if is_coroutine:
+            return await tool(**extra_kwargs)
+        return tool(**extra_kwargs)
+
+    return new_tool
+
+
+def get_configured_tools(
+    tool_ids: list[str], extra_params: dict, user: UserModel
+) -> dict[str, dict]:
+    tools = {}
+    for tool_id in tool_ids:
+        toolkit = Tools.get_tool_by_id(tool_id)
+        if toolkit is None:
+            continue
+
+        module = webui_app.state.TOOLS.get(tool_id, None)
+        if module is None:
+            module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = module
+
+        more_params = {"__id__": tool_id}
+        custom_params = more_params | extra_params
+        has_citation = hasattr(module, "citation") and module.citation
+        handles_files = hasattr(module, "file_handler") and module.file_handler
+        if hasattr(module, "valves") and hasattr(module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id) or {}
+            module.valves = module.Valves(**valves)
+
+        if hasattr(module, "UserValves"):
+            custom_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+        for spec in toolkit.specs:
+            name = spec["name"]
+            callable = getattr(module, name)
+            # convert to function that takes only model params and inserts custom params
+            custom_callable = get_tool_with_custom_params(callable, custom_params)
+
+            tool_dict = {
+                "spec": spec,
+                "citation": has_citation,
+                "file_handler": handles_files,
+                "toolkit_module": module,
+                "callable": custom_callable,
+            }
+            if name in tools:
+                log.warning(f"Tool {name} already exists in another toolkit!")
+                mod_name = tools[name]["toolkit_module"].__name__
+                log.warning(f"Collision between {toolkit} and {mod_name}.")
+                log.warning(f"Discarding {toolkit}.{name}")
+            else:
+                tools[name] = tool_dict
+
+    return tools
+
+
 async def chat_completion_tools_handler(
-    body: dict, user: User, extra_params: dict
+    body: dict, user: UserModel, extra_params: dict
 ) -> tuple[dict, dict]:
-    skip_files = None
-
+    skip_files = False
     contexts = []
-    citations = None
-
+    citations = []
     task_model_id = get_task_model_id(body["model"])
 
     # If tool_ids field is present, call the functions
@@ -507,6 +575,7 @@ async def chat_completion_tools_handler(
         return body, {}
 
     log.debug(f"tool_ids: {body['tool_ids']}")
+    log.info(f"{get_configured_tools(body['tool_ids'], extra_params, user)=}")
     kwargs = {
         "messages": body["messages"],
         "files": body.get("files", []),
@@ -515,6 +584,7 @@ async def chat_completion_tools_handler(
         "user": user,
         "extra_params": extra_params,
     }
+
     for tool_id in body["tool_ids"]:
         log.debug(f"{tool_id=}")
         try:
@@ -526,10 +596,7 @@ async def chat_completion_tools_handler(
                 contexts.append(response)
 
             if citation:
-                if citations is None:
-                    citations = [citation]
-                else:
-                    citations.append(citation)
+                citations.append(citation)
 
             if file_handler:
                 skip_files = True
@@ -540,14 +607,10 @@ async def chat_completion_tools_handler(
     del body["tool_ids"]
     log.debug(f"tool_contexts: {contexts}")
 
-    if skip_files:
-        if "files" in body:
-            del body["files"]
+    if skip_files and "files" in body:
+        del body["files"]
 
-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
+    return body, {"contexts": contexts, "citations": citations}
 
 
 async def chat_completion_files_handler(body):