Browse Source

Merge pull request #7182 from michaelpoluektov/fix/tools-metadata

fix: Fix tools metadata
Timothy Jaeryang Baek 5 months ago
parent
commit
5be7cbfdf5

+ 3 - 1
backend/open_webui/apps/retrieval/main.py

@@ -598,7 +598,9 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
         app.state.config.BRAVE_SEARCH_API_KEY = (
             form_data.web.search.brave_search_api_key
         )
-        app.state.config.MOJEEK_SEARCH_API_KEY = form_data.web.search.mojeek_search_api_key
+        app.state.config.MOJEEK_SEARCH_API_KEY = (
+            form_data.web.search.mojeek_search_api_key
+        )
         app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
         app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
         app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key

+ 2 - 3
backend/open_webui/apps/retrieval/web/mojeek.py

@@ -22,7 +22,7 @@ def search_mojeek(
     headers = {
         "Accept": "application/json",
     }
-    params = {"q": query, "api_key": api_key, 'fmt': 'json', 't': count}
+    params = {"q": query, "api_key": api_key, "fmt": "json", "t": count}
 
     response = requests.get(url, headers=headers, params=params)
     response.raise_for_status()
@@ -32,10 +32,9 @@ def search_mojeek(
     if filter_list:
         results = get_filtered_results(results, filter_list)
 
-
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("desc")
         )
         for result in results
-    ]
+    ]

+ 25 - 29
backend/open_webui/apps/webui/routers/tools.py

@@ -1,4 +1,3 @@
-import os
 from pathlib import Path
 from typing import Optional
 
@@ -10,7 +9,7 @@ from open_webui.apps.webui.models.tools import (
     Tools,
 )
 from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports
-from open_webui.config import CACHE_DIR, DATA_DIR
+from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from open_webui.utils.tools import get_tools_specs
@@ -300,38 +299,35 @@ async def update_tools_valves_by_id(
     request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
 ):
     tools = Tools.get_tool_by_id(id)
-    if tools:
-        if id in request.app.state.TOOLS:
-            tools_module = request.app.state.TOOLS[id]
-        else:
-            tools_module, _ = load_tools_module_by_id(id)
-            request.app.state.TOOLS[id] = tools_module
-
-        if hasattr(tools_module, "Valves"):
-            Valves = tools_module.Valves
-
-            try:
-                form_data = {k: v for k, v in form_data.items() if v is not None}
-                valves = Valves(**form_data)
-                Tools.update_tool_valves_by_id(id, valves.model_dump())
-                return valves.model_dump()
-            except Exception as e:
-                print(e)
-                raise HTTPException(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    detail=ERROR_MESSAGES.DEFAULT(str(e)),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
-
+    if not tools:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+    if id in request.app.state.TOOLS:
+        tools_module = request.app.state.TOOLS[id]
     else:
+        tools_module, _ = load_tools_module_by_id(id)
+        request.app.state.TOOLS[id] = tools_module
+
+    if not hasattr(tools_module, "Valves"):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )
+    Valves = tools_module.Valves
+
+    try:
+        form_data = {k: v for k, v in form_data.items() if v is not None}
+        valves = Valves(**form_data)
+        Tools.update_tool_valves_by_id(id, valves.model_dump())
+        return valves.model_dump()
+    except Exception as e:
+        print(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(str(e)),
+        )
 
 
 ############################

+ 0 - 1
backend/open_webui/main.py

@@ -1313,7 +1313,6 @@ async def generate_chat_completions(
 
 @app.post("/api/chat/completed")
 async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
-
     model_list = await get_all_models()
     models = {model["id"]: model for model in model_list}
 

+ 0 - 112
backend/open_webui/utils/schemas.py

@@ -1,112 +0,0 @@
-from ast import literal_eval
-from typing import Any, Literal, Optional, Type
-
-from pydantic import BaseModel, Field, create_model
-
-
-def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
-    """
-    Converts a JSON schema to a Pydantic BaseModel class.
-
-    Args:
-        json_schema: The JSON schema to convert.
-
-    Returns:
-        A Pydantic BaseModel class.
-    """
-
-    # Extract the model name from the schema title.
-    model_name = tool_dict["name"]
-    schema = tool_dict["parameters"]
-
-    # Extract the field definitions from the schema properties.
-    field_definitions = {
-        name: json_schema_to_pydantic_field(name, prop, schema.get("required", []))
-        for name, prop in schema.get("properties", {}).items()
-    }
-
-    # Create the BaseModel class using create_model().
-    return create_model(model_name, **field_definitions)
-
-
-def json_schema_to_pydantic_field(
-    name: str, json_schema: dict[str, Any], required: list[str]
-) -> Any:
-    """
-    Converts a JSON schema property to a Pydantic field definition.
-
-    Args:
-        name: The field name.
-        json_schema: The JSON schema property.
-
-    Returns:
-        A Pydantic field definition.
-    """
-
-    # Get the field type.
-    type_ = json_schema_to_pydantic_type(json_schema)
-
-    # Get the field description.
-    description = json_schema.get("description")
-
-    # Get the field examples.
-    examples = json_schema.get("examples")
-
-    # Create a Field object with the type, description, and examples.
-    # The 'required' flag will be set later when creating the model.
-    return (
-        type_,
-        Field(
-            description=description,
-            examples=examples,
-            default=... if name in required else None,
-        ),
-    )
-
-
-def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
-    """
-    Converts a JSON schema type to a Pydantic type.
-
-    Args:
-        json_schema: The JSON schema to convert.
-
-    Returns:
-        A Pydantic type.
-    """
-
-    type_ = json_schema.get("type")
-
-    if type_ == "string" or type_ == "str":
-        return str
-    elif type_ == "integer" or type_ == "int":
-        return int
-    elif type_ == "number" or type_ == "float":
-        return float
-    elif type_ == "boolean" or type_ == "bool":
-        return bool
-    elif type_ == "array" or type_ == "list":
-        items_schema = json_schema.get("items")
-        if items_schema:
-            item_type = json_schema_to_pydantic_type(items_schema)
-            return list[item_type]
-        else:
-            return list
-    elif type_ == "object":
-        # Handle nested models.
-        properties = json_schema.get("properties")
-        if properties:
-            nested_model = json_schema_to_model(json_schema)
-            return nested_model
-        else:
-            return dict
-    elif type_ == "null":
-        return Optional[Any]  # Use Optional[Any] for nullable fields
-    elif type_ == "literal":
-        return Literal[literal_eval(json_schema.get("enum"))]
-    elif type_ == "optional":
-        inner_schema = json_schema.get("items", {"type": "string"})
-        inner_type = json_schema_to_pydantic_type(inner_schema)
-        return Optional[inner_type]
-    else:
-        raise ValueError(f"Unsupported JSON schema type: {type_}")

+ 79 - 88
backend/open_webui/utils/tools.py

@@ -1,11 +1,14 @@
 import inspect
 import logging
-from typing import Awaitable, Callable, get_type_hints
+import re
+from typing import Any, Awaitable, Callable, get_type_hints
+from functools import update_wrapper, partial
 
+from langchain_core.utils.function_calling import convert_to_openai_function
 from open_webui.apps.webui.models.tools import Tools
 from open_webui.apps.webui.models.users import UserModel
 from open_webui.apps.webui.utils import load_tools_module_by_id
-from open_webui.utils.schemas import json_schema_to_model
+from pydantic import BaseModel, Field, create_model
 
 log = logging.getLogger(__name__)
 
@@ -14,17 +17,16 @@ def apply_extra_params_to_tool_function(
     function: Callable, extra_params: dict
 ) -> Callable[..., Awaitable]:
     sig = inspect.signature(function)
-    extra_params = {
-        key: value for key, value in extra_params.items() if key in sig.parameters
-    }
-    is_coroutine = inspect.iscoroutinefunction(function)
+    extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
+    partial_func = partial(function, **extra_params)
+    if inspect.iscoroutinefunction(function):
+        update_wrapper(partial_func, function)
+        return partial_func
 
-    async def new_function(**kwargs):
-        extra_kwargs = kwargs | extra_params
-        if is_coroutine:
-            return await function(**extra_kwargs)
-        return function(**extra_kwargs)
+    async def new_function(*args, **kwargs):
+        return partial_func(*args, **kwargs)
 
+    update_wrapper(new_function, function)
     return new_function
 
 
@@ -55,11 +57,6 @@ def get_tools(
             )
 
         for spec in tools.specs:
-            # TODO: Fix hack for OpenAI API
-            for val in spec.get("parameters", {}).get("properties", {}).values():
-                if val["type"] == "str":
-                    val["type"] = "string"
-
             # Remove internal parameters
             spec["parameters"]["properties"] = {
                 key: val
@@ -72,15 +69,12 @@ def get_tools(
             # convert to function that takes only model params and inserts custom params
             original_func = getattr(module, function_name)
             callable = apply_extra_params_to_tool_function(original_func, extra_params)
-            if hasattr(original_func, "__doc__"):
-                callable.__doc__ = original_func.__doc__
-
             # TODO: This needs to be a pydantic model
             tool_dict = {
                 "toolkit_id": tool_id,
                 "callable": callable,
                 "spec": spec,
-                "pydantic_model": json_schema_to_model(spec),
+                "pydantic_model": function_to_pydantic_model(callable),
                 "file_handler": hasattr(module, "file_handler") and module.file_handler,
                 "citation": hasattr(module, "citation") and module.citation,
             }
@@ -96,78 +90,75 @@ def get_tools(
     return tools_dict
 
 
-def doc_to_dict(docstring):
-    lines = docstring.split("\n")
-    description = lines[1].strip()
-    param_dict = {}
+def parse_docstring(docstring):
+    """
+    Parse a function's docstring to extract parameter descriptions in reST format.
+
+    Args:
+        docstring (str): The docstring to parse.
+
+    Returns:
+        dict: A dictionary where keys are parameter names and values are descriptions.
+    """
+    if not docstring:
+        return {}
+
+    # Regex to match `:param name: description` format
+    param_pattern = re.compile(r":param (\w+):\s*(.+)")
+    param_descriptions = {}
+
+    for line in docstring.splitlines():
+        match = param_pattern.match(line.strip())
+        if match:
+            param_name, param_description = match.groups()
+            param_descriptions[param_name] = param_description
+
+    return param_descriptions
+
 
-    for line in lines:
-        if ":param" in line:
-            line = line.replace(":param", "").strip()
-            param, desc = line.split(":", 1)
-            param_dict[param.strip()] = desc.strip()
-    ret_dict = {"description": description, "params": param_dict}
-    return ret_dict
+def function_to_pydantic_model(func: Callable) -> type[BaseModel]:
+    """
+    Converts a Python function's type hints and docstring to a Pydantic model,
+    including support for nested types, default values, and descriptions.
 
+    Args:
+        func: The function whose type hints and docstring should be converted.
+        model_name: The name of the generated Pydantic model.
 
-def get_tools_specs(tools) -> list[dict]:
-    function_list = [
-        {"name": func, "function": getattr(tools, func)}
-        for func in dir(tools)
-        if callable(getattr(tools, func))
+    Returns:
+        A Pydantic model class.
+    """
+    type_hints = get_type_hints(func)
+    signature = inspect.signature(func)
+    parameters = signature.parameters
+
+    docstring = func.__doc__
+    descriptions = parse_docstring(docstring)
+
+    field_defs = {}
+    for name, param in parameters.items():
+        type_hint = type_hints.get(name, Any)
+        default_value = param.default if param.default is not param.empty else ...
+        description = descriptions.get(name, None)
+        if not description:
+            field_defs[name] = type_hint, default_value
+            continue
+        field_defs[name] = type_hint, Field(default_value, description=description)
+
+    return create_model(func.__name__, **field_defs)
+
+
+def get_callable_attributes(tool: object) -> list[Callable]:
+    return [
+        getattr(tool, func)
+        for func in dir(tool)
+        if callable(getattr(tool, func))
         and not func.startswith("__")
-        and not inspect.isclass(getattr(tools, func))
+        and not inspect.isclass(getattr(tool, func))
     ]
 
-    specs = []
-    for function_item in function_list:
-        function_name = function_item["name"]
-        function = function_item["function"]
-
-        function_doc = doc_to_dict(function.__doc__ or function_name)
-        specs.append(
-            {
-                "name": function_name,
-                # TODO: multi-line desc?
-                "description": function_doc.get("description", function_name),
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        param_name: {
-                            "type": param_annotation.__name__.lower(),
-                            **(
-                                {
-                                    "enum": (
-                                        str(param_annotation.__args__)
-                                        if hasattr(param_annotation, "__args__")
-                                        else None
-                                    )
-                                }
-                                if hasattr(param_annotation, "__args__")
-                                else {}
-                            ),
-                            "description": function_doc.get("params", {}).get(
-                                param_name, param_name
-                            ),
-                        }
-                        for param_name, param_annotation in get_type_hints(
-                            function
-                        ).items()
-                        if param_name != "return"
-                        and not (
-                            param_name.startswith("__") and param_name.endswith("__")
-                        )
-                    },
-                    "required": [
-                        name
-                        for name, param in inspect.signature(
-                            function
-                        ).parameters.items()
-                        if param.default is param.empty
-                        and not (name.startswith("__") and name.endswith("__"))
-                    ],
-                },
-            }
-        )
 
-    return specs
+def get_tools_specs(tool_class: object) -> list[dict]:
+    function_list = get_callable_attributes(tool_class)
+    models = map(function_to_pydantic_model, function_list)
+    return [convert_to_openai_function(tool) for tool in models]