فهرست منبع

add get_configured_tools

Michael Poluektov 8 ماه پیش
والد
کامیت
6df6170c44
1فایلهای تغییر یافته به همراه81 افزوده شده و 18 حذف شده
  1. 81 18
      backend/main.py

+ 81 - 18
backend/main.py

@@ -51,13 +51,13 @@ from apps.webui.internal.db import Session
 
 
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
-from typing import Optional
+from typing import Optional, Callable, Awaitable
 
 
 from apps.webui.models.auths import Auths
 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
 from apps.webui.models.models import Models
 from apps.webui.models.tools import Tools
 from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users, User
+from apps.webui.models.users import Users, UserModel
 
 
 from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
 from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
 
 
@@ -356,6 +356,7 @@ async def get_tool_call_response(
         return None, None, False
         return None, None, False
 
 
     tools_specs = json.dumps(tool.specs, indent=2)
     tools_specs = json.dumps(tool.specs, indent=2)
+    log.debug(f"{tool.specs=}")
     content = tool_calling_generation_template(template, tools_specs)
     content = tool_calling_generation_template(template, tools_specs)
     payload = get_tool_call_payload(messages, task_model_id, content)
     payload = get_tool_call_payload(messages, task_model_id, content)
 
 
@@ -492,14 +493,81 @@ async def chat_completion_inlets_handler(body, model, extra_params):
     return body, {}
     return body, {}
 
 
 
 
+def get_tool_with_custom_params(
+    tool: Callable, custom_params: dict
+) -> Callable[..., Awaitable]:
+    sig = inspect.signature(tool)
+    extra_params = {
+        key: value for key, value in custom_params.items() if key in sig.parameters
+    }
+    is_coroutine = inspect.iscoroutinefunction(tool)
+
+    async def new_tool(**kwargs):
+        extra_kwargs = kwargs | extra_params
+        if is_coroutine:
+            return await tool(**extra_kwargs)
+        return tool(**extra_kwargs)
+
+    return new_tool
+
+
+def get_configured_tools(
+    tool_ids: list[str], extra_params: dict, user: UserModel
+) -> dict[str, dict]:
+    tools = {}
+    for tool_id in tool_ids:
+        toolkit = Tools.get_tool_by_id(tool_id)
+        if toolkit is None:
+            continue
+
+        module = webui_app.state.TOOLS.get(tool_id, None)
+        if module is None:
+            module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = module
+
+        more_params = {"__id__": tool_id}
+        custom_params = more_params | extra_params
+        has_citation = hasattr(module, "citation") and module.citation
+        handles_files = hasattr(module, "file_handler") and module.file_handler
+        if hasattr(module, "valves") and hasattr(module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id) or {}
+            module.valves = module.Valves(**valves)
+
+        if hasattr(module, "UserValves"):
+            custom_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+        for spec in toolkit.specs:
+            name = spec["name"]
+            callable = getattr(module, name)
+            # convert to function that takes only model params and inserts custom params
+            custom_callable = get_tool_with_custom_params(callable, custom_params)
+
+            tool_dict = {
+                "spec": spec,
+                "citation": has_citation,
+                "file_handler": handles_files,
+                "toolkit_module": module,
+                "callable": custom_callable,
+            }
+            if name in tools:
+                log.warning(f"Tool {name} already exists in another toolkit!")
+                mod_name = tools[name]["toolkit_module"].__name__
+                log.warning(f"Collision between {toolkit} and {mod_name}.")
+                log.warning(f"Discarding {toolkit}.{name}")
+            else:
+                tools[name] = tool_dict
+
+    return tools
+
+
 async def chat_completion_tools_handler(
 async def chat_completion_tools_handler(
-    body: dict, user: User, extra_params: dict
+    body: dict, user: UserModel, extra_params: dict
 ) -> tuple[dict, dict]:
 ) -> tuple[dict, dict]:
-    skip_files = None
-
+    skip_files = False
     contexts = []
     contexts = []
-    citations = None
-
+    citations = []
     task_model_id = get_task_model_id(body["model"])
     task_model_id = get_task_model_id(body["model"])
 
 
     # If tool_ids field is present, call the functions
     # If tool_ids field is present, call the functions
@@ -507,6 +575,7 @@ async def chat_completion_tools_handler(
         return body, {}
         return body, {}
 
 
     log.debug(f"tool_ids: {body['tool_ids']}")
     log.debug(f"tool_ids: {body['tool_ids']}")
+    log.info(f"{get_configured_tools(body['tool_ids'], extra_params, user)=}")
     kwargs = {
     kwargs = {
         "messages": body["messages"],
         "messages": body["messages"],
         "files": body.get("files", []),
         "files": body.get("files", []),
@@ -515,6 +584,7 @@ async def chat_completion_tools_handler(
         "user": user,
         "user": user,
         "extra_params": extra_params,
         "extra_params": extra_params,
     }
     }
+
     for tool_id in body["tool_ids"]:
     for tool_id in body["tool_ids"]:
         log.debug(f"{tool_id=}")
         log.debug(f"{tool_id=}")
         try:
         try:
@@ -526,10 +596,7 @@ async def chat_completion_tools_handler(
                 contexts.append(response)
                 contexts.append(response)
 
 
             if citation:
             if citation:
-                if citations is None:
-                    citations = [citation]
-                else:
-                    citations.append(citation)
+                citations.append(citation)
 
 
             if file_handler:
             if file_handler:
                 skip_files = True
                 skip_files = True
@@ -540,14 +607,10 @@ async def chat_completion_tools_handler(
     del body["tool_ids"]
     del body["tool_ids"]
     log.debug(f"tool_contexts: {contexts}")
     log.debug(f"tool_contexts: {contexts}")
 
 
-    if skip_files:
-        if "files" in body:
-            del body["files"]
+    if skip_files and "files" in body:
+        del body["files"]
 
 
-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
+    return body, {"contexts": contexts, "citations": citations}
 
 
 
 
 async def chat_completion_files_handler(body):
 async def chat_completion_files_handler(body):