浏览代码

fix tools metadata

Michael Poluektov 5 月之前
父节点
当前提交
70838148e7
共有 3 个文件被更改,包括 108 次插入 和 120 次删除
  1. 25 29
      backend/open_webui/apps/webui/routers/tools.py
  2. 4 1
      backend/open_webui/utils/schemas.py
  3. 79 90
      backend/open_webui/utils/tools.py

+ 25 - 29
backend/open_webui/apps/webui/routers/tools.py

@@ -1,4 +1,3 @@
-import os
 from pathlib import Path
 from typing import Optional
 
@@ -10,7 +9,7 @@ from open_webui.apps.webui.models.tools import (
     Tools,
 )
 from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports
-from open_webui.config import CACHE_DIR, DATA_DIR
+from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from open_webui.utils.tools import get_tools_specs
@@ -300,38 +299,35 @@ async def update_tools_valves_by_id(
     request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
 ):
     tools = Tools.get_tool_by_id(id)
-    if tools:
-        if id in request.app.state.TOOLS:
-            tools_module = request.app.state.TOOLS[id]
-        else:
-            tools_module, _ = load_tools_module_by_id(id)
-            request.app.state.TOOLS[id] = tools_module
-
-        if hasattr(tools_module, "Valves"):
-            Valves = tools_module.Valves
-
-            try:
-                form_data = {k: v for k, v in form_data.items() if v is not None}
-                valves = Valves(**form_data)
-                Tools.update_tool_valves_by_id(id, valves.model_dump())
-                return valves.model_dump()
-            except Exception as e:
-                print(e)
-                raise HTTPException(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    detail=ERROR_MESSAGES.DEFAULT(str(e)),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
-
+    if not tools:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+    if id in request.app.state.TOOLS:
+        tools_module = request.app.state.TOOLS[id]
     else:
+        tools_module, _ = load_tools_module_by_id(id)
+        request.app.state.TOOLS[id] = tools_module
+
+    if not hasattr(tools_module, "Valves"):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )
+    Valves = tools_module.Valves
+
+    try:
+        form_data = {k: v for k, v in form_data.items() if v is not None}
+        valves = Valves(**form_data)
+        Tools.update_tool_valves_by_id(id, valves.model_dump())
+        return valves.model_dump()
+    except Exception as e:
+        print(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(str(e)),
+        )
 
 
 ############################

+ 4 - 1
backend/open_webui/utils/schemas.py

@@ -103,7 +103,10 @@ def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
     elif type_ == "null":
         return Optional[Any]  # Use Optional[Any] for nullable fields
     elif type_ == "literal":
-        return Literal[literal_eval(json_schema.get("enum"))]
+        enum = json_schema.get("enum")
+        if enum is None:
+            raise ValueError("Enum values must be provided for 'literal' type.")
+        return Literal[literal_eval(enum)]
     elif type_ == "optional":
         inner_schema = json_schema.get("items", {"type": "string"})
         inner_type = json_schema_to_pydantic_type(inner_schema)

+ 79 - 90
backend/open_webui/utils/tools.py

@@ -1,11 +1,14 @@
 import inspect
 import logging
-from typing import Awaitable, Callable, get_type_hints
+import re
+from typing import Any, Awaitable, Callable, get_type_hints
+from functools import update_wrapper, partial
 
+from langchain_core.utils.function_calling import convert_to_openai_function
 from open_webui.apps.webui.models.tools import Tools
 from open_webui.apps.webui.models.users import UserModel
 from open_webui.apps.webui.utils import load_tools_module_by_id
-from open_webui.utils.schemas import json_schema_to_model
+from pydantic import BaseModel, Field, create_model
 
 log = logging.getLogger(__name__)
 
@@ -13,18 +16,15 @@ log = logging.getLogger(__name__)
 def apply_extra_params_to_tool_function(
     function: Callable, extra_params: dict
 ) -> Callable[..., Awaitable]:
-    sig = inspect.signature(function)
-    extra_params = {
-        key: value for key, value in extra_params.items() if key in sig.parameters
-    }
-    is_coroutine = inspect.iscoroutinefunction(function)
-
-    async def new_function(**kwargs):
-        extra_kwargs = kwargs | extra_params
-        if is_coroutine:
-            return await function(**extra_kwargs)
-        return function(**extra_kwargs)
+    partial_func = partial(function, **extra_params)
+    if inspect.iscoroutinefunction(function):
+        update_wrapper(partial_func, function)
+        return partial_func
 
+    async def new_function(*args, **kwargs):
+        return partial_func(*args, **kwargs)
+
+    update_wrapper(new_function, function)
     return new_function
 
 
@@ -55,11 +55,6 @@ def get_tools(
             )
 
         for spec in tools.specs:
-            # TODO: Fix hack for OpenAI API
-            for val in spec.get("parameters", {}).get("properties", {}).values():
-                if val["type"] == "str":
-                    val["type"] = "string"
-
             # Remove internal parameters
             spec["parameters"]["properties"] = {
                 key: val
@@ -72,15 +67,12 @@ def get_tools(
             # convert to function that takes only model params and inserts custom params
             original_func = getattr(module, function_name)
             callable = apply_extra_params_to_tool_function(original_func, extra_params)
-            if hasattr(original_func, "__doc__"):
-                callable.__doc__ = original_func.__doc__
-
             # TODO: This needs to be a pydantic model
             tool_dict = {
                 "toolkit_id": tool_id,
                 "callable": callable,
                 "spec": spec,
-                "pydantic_model": json_schema_to_model(spec),
+                "pydantic_model": function_to_pydantic_model(callable),
                 "file_handler": hasattr(module, "file_handler") and module.file_handler,
                 "citation": hasattr(module, "citation") and module.citation,
             }
@@ -96,78 +88,75 @@ def get_tools(
     return tools_dict
 
 
-def doc_to_dict(docstring):
-    lines = docstring.split("\n")
-    description = lines[1].strip()
-    param_dict = {}
+def parse_docstring(docstring):
+    """
+    Parse a function's docstring to extract parameter descriptions in reST format.
+
+    Args:
+        docstring (str): The docstring to parse.
+
+    Returns:
+        dict: A dictionary where keys are parameter names and values are descriptions.
+    """
+    if not docstring:
+        return {}
+
+    # Regex to match `:param name: description` format
+    param_pattern = re.compile(r":param (\w+):\s*(.+)")
+    param_descriptions = {}
+
+    for line in docstring.splitlines():
+        match = param_pattern.match(line.strip())
+        if match:
+            param_name, param_description = match.groups()
+            param_descriptions[param_name] = param_description
+
+    return param_descriptions
 
-    for line in lines:
-        if ":param" in line:
-            line = line.replace(":param", "").strip()
-            param, desc = line.split(":", 1)
-            param_dict[param.strip()] = desc.strip()
-    ret_dict = {"description": description, "params": param_dict}
-    return ret_dict
 
+def function_to_pydantic_model(func: Callable) -> type[BaseModel]:
+    """
+    Converts a Python function's type hints and docstring to a Pydantic model,
+    including support for nested types, default values, and descriptions.
 
-def get_tools_specs(tools) -> list[dict]:
-    function_list = [
-        {"name": func, "function": getattr(tools, func)}
-        for func in dir(tools)
-        if callable(getattr(tools, func))
+    Args:
+        func: The function whose type hints and docstring should be converted.
+            The generated model is named after ``func.__name__``.
+
+    Returns:
+        A Pydantic model class.
+    """
+    type_hints = get_type_hints(func)
+    signature = inspect.signature(func)
+    parameters = signature.parameters
+
+    docstring = func.__doc__
+    descriptions = parse_docstring(docstring)
+
+    field_defs = {}
+    for name, param in parameters.items():
+        type_hint = type_hints.get(name, Any)
+        default_value = param.default if param.default is not param.empty else ...
+        description = descriptions.get(name, None)
+        if not description:
+            field_defs[name] = type_hint, default_value
+            continue
+        field_defs[name] = type_hint, Field(default_value, description=description)
+
+    return create_model(func.__name__, **field_defs)
+
+
+def get_callable_attributes(tool: object) -> list[Callable]:
+    return [
+        getattr(tool, func)
+        for func in dir(tool)
+        if callable(getattr(tool, func))
         and not func.startswith("__")
-        and not inspect.isclass(getattr(tools, func))
+        and not inspect.isclass(getattr(tool, func))
     ]
 
-    specs = []
-    for function_item in function_list:
-        function_name = function_item["name"]
-        function = function_item["function"]
-
-        function_doc = doc_to_dict(function.__doc__ or function_name)
-        specs.append(
-            {
-                "name": function_name,
-                # TODO: multi-line desc?
-                "description": function_doc.get("description", function_name),
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        param_name: {
-                            "type": param_annotation.__name__.lower(),
-                            **(
-                                {
-                                    "enum": (
-                                        str(param_annotation.__args__)
-                                        if hasattr(param_annotation, "__args__")
-                                        else None
-                                    )
-                                }
-                                if hasattr(param_annotation, "__args__")
-                                else {}
-                            ),
-                            "description": function_doc.get("params", {}).get(
-                                param_name, param_name
-                            ),
-                        }
-                        for param_name, param_annotation in get_type_hints(
-                            function
-                        ).items()
-                        if param_name != "return"
-                        and not (
-                            param_name.startswith("__") and param_name.endswith("__")
-                        )
-                    },
-                    "required": [
-                        name
-                        for name, param in inspect.signature(
-                            function
-                        ).parameters.items()
-                        if param.default is param.empty
-                        and not (name.startswith("__") and name.endswith("__"))
-                    ],
-                },
-            }
-        )
 
-    return specs
+def get_tools_specs(tool_class: object) -> list[dict]:
+    function_list = get_callable_attributes(tool_class)
+    models = map(function_to_pydantic_model, function_list)
+    return [convert_to_openai_function(tool) for tool in models]