瀏覽代碼

refac: tools

Timothy Jaeryang Baek 1 月之前
父節點
當前提交
5a7efad59c
共有 3 個文件被更改,包括 86 次插入13 次删除
  1. 64 13
      backend/open_webui/utils/middleware.py
  2. 3 0
      backend/open_webui/utils/tools.py
  3. 19 0
      src/routes/+layout.svelte

+ 64 - 13
backend/open_webui/utils/middleware.py

@@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
 async def chat_completion_tools_handler(
-    request: Request, body: dict, user: UserModel, models, tools
+    request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
 ) -> tuple[dict, dict]:
     async def get_content_from_response(response) -> Optional[str]:
         content = None
@@ -135,6 +135,9 @@ async def chat_completion_tools_handler(
             "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
         }
 
+    event_caller = extra_params["__event_call__"]
+    metadata = extra_params["__metadata__"]
+
     task_model_id = get_task_model_id(
         body["model"],
         request.app.state.config.TASK_MODEL,
@@ -189,17 +192,33 @@ async def chat_completion_tools_handler(
                 tool_function_params = tool_call.get("parameters", {})
 
                 try:
-                    spec = tools[tool_function_name].get("spec", {})
+                    tool = tools[tool_function_name]
+
+                    spec = tool.get("spec", {})
                     allowed_params = (
                         spec.get("parameters", {}).get("properties", {}).keys()
                     )
-                    tool_function = tools[tool_function_name]["callable"]
+                    tool_function = tool["callable"]
                     tool_function_params = {
                         k: v
                         for k, v in tool_function_params.items()
                         if k in allowed_params
                     }
-                    tool_output = await tool_function(**tool_function_params)
+
+                    if tool.get("direct", False):
+                        tool_output = await tool_function(**tool_function_params)
+                    else:
+                        tool_output = await event_caller(
+                            {
+                                "type": "execute:tool",
+                                "data": {
+                                    "id": str(uuid4()),
+                                    "tool": tool,
+                                    "params": tool_function_params,
+                                    "session_id": metadata.get("session_id", None),
+                                },
+                            }
+                        )
 
                 except Exception as e:
                     tool_output = str(e)
@@ -764,12 +783,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     }
     form_data["metadata"] = metadata
 
+    # Server side tools
     tool_ids = metadata.get("tool_ids", None)
+    # Client side tools
+    tool_specs = form_data.get("tool_specs", None)
+
     log.debug(f"{tool_ids=}")
+    log.debug(f"{tool_specs=}")
+
+    tools_dict = {}
 
     if tool_ids:
-        # If tool_ids field is present, then get the tools
-        tools = get_tools(
+        tools_dict = get_tools(
             request,
             tool_ids,
             user,
@@ -780,20 +805,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                 "__files__": metadata.get("files", []),
             },
         )
-        log.info(f"{tools=}")
+        log.info(f"{tools_dict=}")
+
+    if tool_specs:
+        for tool in tool_specs:
+            callable = tool.pop("callable", None)
+            tools_dict[tool["name"]] = {
+                "direct": True,
+                "callable": callable,
+                "spec": tool,
+            }
 
+    if tools_dict:
         if metadata.get("function_calling") == "native":
             # If the function calling is native, then call the tools function calling handler
-            metadata["tools"] = tools
+            metadata["tools"] = tools_dict
             form_data["tools"] = [
                 {"type": "function", "function": tool.get("spec", {})}
-                for tool in tools.values()
+                for tool in tools_dict.values()
             ]
         else:
             # If the function calling is not native, then call the tools function calling handler
             try:
                 form_data, flags = await chat_completion_tools_handler(
-                    request, form_data, user, models, tools
+                    request, form_data, extra_params, user, models, tools_dict
                 )
                 sources.extend(flags.get("sources", []))
 
@@ -1774,9 +1809,25 @@ async def process_chat_response(
                                     for k, v in tool_function_params.items()
                                     if k in allowed_params
                                 }
-                                tool_result = await tool_function(
-                                    **tool_function_params
-                                )
+
+                                if tool.get("direct", False):
+                                    tool_result = await tool_function(
+                                        **tool_function_params
+                                    )
+                                else:
+                                    tool_result = await event_caller(
+                                        {
+                                            "type": "execute:tool",
+                                            "data": {
+                                                "id": str(uuid4()),
+                                                "tool": tool,
+                                                "params": tool_function_params,
+                                                "session_id": metadata.get(
+                                                    "session_id", None
+                                                ),
+                                            },
+                                        }
+                                    )
                             except Exception as e:
                                 tool_result = str(e)
 

+ 3 - 0
backend/open_webui/utils/tools.py

@@ -1,6 +1,9 @@
 import inspect
 import logging
 import re
+import inspect
+import uuid
+
 from typing import Any, Awaitable, Callable, get_type_hints
 from functools import update_wrapper, partial
 

+ 19 - 0
src/routes/+layout.svelte

@@ -203,6 +203,22 @@
 		};
 	};
 
+	const executeTool = async (data, cb) => {
+		console.log(data);
+		// TODO: MCP (SSE) support
+		// TODO: API Server support
+
+		if (cb) {
+			cb(
+				JSON.parse(
+					JSON.stringify({
+						result: null
+					})
+				)
+			);
+		}
+	};
+
 	const chatEventHandler = async (event, cb) => {
 		const chat = $page.url.pathname.includes(`/c/${event.chat_id}`);
 
@@ -256,6 +272,9 @@
 			if (type === 'execute:python') {
 				console.log('execute:python', data);
 				executePythonAsWorker(data.id, data.code, cb);
+			} else if (type === 'execute:tool') {
+				console.log('execute:tool', data);
+				executeTool(data, cb);
 			} else if (type === 'request:chat:completion') {
 				console.log(data, $socket.id);
 				const { session_id, channel, form_data, model } = data;