
Merge pull request #4724 from michaelpoluektov/tools-refac-2.1

feat: Add `__tools__` optional param for function pipes
Timothy Jaeryang Baek, 8 months ago
Parent
Commit
ee526b4b07
6 changed files, with 249 additions and 133 deletions
  1. backend/apps/ollama/main.py (+2 -5)
  2. backend/apps/webui/main.py (+19 -11)
  3. backend/config.py (+0 -1)
  4. backend/main.py (+38 -115)
  5. backend/utils/schemas.py (+104 -0)
  6. backend/utils/tools.py (+86 -1)

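For context, here is a minimal, hypothetical sketch of what the new parameter looks like from a pipe author's side. It assumes the tool-dict shape built by `utils.tools.get_tools` in this PR (`toolkit_id`, `callable`, `spec`, `pydantic_model`, ...); the class layout, return value, and tool names are illustrative only, and pipes that never declare `__tools__` are unaffected.

    # Hypothetical example (not part of this diff): a function pipe that opts in
    # to the new optional __tools__ parameter. The tool-dict shape mirrors what
    # utils.tools.get_tools builds below; everything else is illustrative.
    class Pipe:
        async def pipe(self, body: dict, __tools__: dict | None = None) -> str:
            if not __tools__:
                return "No tools were configured for this chat."
            lines = []
            for name, tool in __tools__.items():
                props = tool["spec"].get("parameters", {}).get("properties", {})
                lines.append(f"{name}({', '.join(props)}) from toolkit {tool['toolkit_id']}")
            return "\n".join(lines)
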
+ 2 - 5
backend/apps/ollama/main.py

@@ -731,11 +731,8 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
-
-    payload = {
-        **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
-    }
+    payload = {**form_data.model_dump(exclude_none=True)}
+    log.debug(f"{payload = }")
     if "metadata" in payload:
         del payload["metadata"]
 

+ 19 - 11
backend/apps/webui/main.py

@@ -26,6 +26,7 @@ from utils.misc import (
     apply_model_system_prompt_to_body,
 )
 
+from utils.tools import get_tools
 
 from config import (
     SHOW_ADMIN_DETAILS,
@@ -275,7 +276,9 @@ def get_function_params(function_module, form_data, user, extra_params={}):
 async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
-    metadata = form_data.pop("metadata", None)
+    metadata = form_data.pop("metadata", {})
+    files = metadata.get("files", [])
+    tool_ids = metadata.get("tool_ids", [])
 
     __event_emitter__ = None
     __event_call__ = None
@@ -287,6 +290,20 @@ async def generate_function_chat_completion(form_data, user):
             __event_call__ = get_event_call(metadata)
         __task__ = metadata.get("task", None)
 
+    extra_params = {
+        "__event_emitter__": __event_emitter__,
+        "__event_call__": __event_call__,
+        "__task__": __task__,
+    }
+    tools_params = {
+        **extra_params,
+        "__model__": app.state.MODELS[form_data["model"]],
+        "__messages__": form_data["messages"],
+        "__files__": files,
+    }
+    configured_tools = get_tools(app, tool_ids, user, tools_params)
+
+    extra_params["__tools__"] = configured_tools
     if model_info:
         if model_info.base_model_id:
             form_data["model"] = model_info.base_model_id
@@ -299,16 +316,7 @@ async def generate_function_chat_completion(form_data, user):
     function_module = get_function_module(pipe_id)
 
     pipe = function_module.pipe
-    params = get_function_params(
-        function_module,
-        form_data,
-        user,
-        {
-            "__event_emitter__": __event_emitter__,
-            "__event_call__": __event_call__,
-            "__task__": __task__,
-        },
-    )
+    params = get_function_params(function_module, form_data, user, extra_params)
 
     if form_data["stream"]:
 

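A note on why the parameter stays optional: the assumption here is that `get_function_params` (defined earlier in this file, not shown in the hunk) only forwards an extra param when the pipe's signature names it, so existing pipes never see `__tools__`. A rough sketch of that filtering idea, with made-up pipe functions:

    import inspect

    # Assumption: extra params are filtered against the pipe signature, which is
    # what keeps __tools__ optional. Both pipes below are made up.
    def filter_by_signature(func, extra_params: dict) -> dict:
        sig = inspect.signature(func)
        return {k: v for k, v in extra_params.items() if k in sig.parameters}

    def legacy_pipe(body: dict): ...

    def tool_aware_pipe(body: dict, __tools__: dict): ...

    extra = {"__tools__": {"get_weather": {"spec": {}}}, "__task__": None}
    assert "__tools__" not in filter_by_signature(legacy_pipe, extra)
    assert "__tools__" in filter_by_signature(tool_aware_pipe, extra)
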
+ 0 - 1
backend/config.py

@@ -176,7 +176,6 @@ for version in soup.find_all("h2"):
 
 CHANGELOG = changelog_json
 
-
 ####################################
 # SAFE_MODE
 ####################################

+ 38 - 115
backend/main.py

@@ -14,6 +14,7 @@ import requests
 import mimetypes
 import shutil
 import inspect
+from typing import Optional
 
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi.staticfiles import StaticFiles
@@ -51,15 +52,13 @@ from apps.webui.internal.db import Session
 
 
 from pydantic import BaseModel
-from typing import Optional, Callable, Awaitable
 
 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
-from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
 from apps.webui.models.users import Users, UserModel
 
-from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
+from apps.webui.utils import load_function_module_by_id
 
 from utils.utils import (
     get_admin_user,
@@ -76,6 +75,8 @@ from utils.task import (
     tools_function_calling_generation_template,
     moa_response_generation_template,
 )
+
+from utils.tools import get_tools
 from utils.misc import (
     get_last_user_message,
     add_or_update_system_message,
@@ -325,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
             print(f"Error: {e}")
             raise e
 
-    if skip_files and "files" in body:
-        del body["files"]
+    if skip_files and "files" in body.get("metadata", {}):
+        del body["metadata"]["files"]
 
     return body, {}
 
@@ -351,80 +352,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content):
     }
 
 
-def apply_extra_params_to_tool_function(
-    function: Callable, extra_params: dict
-) -> Callable[..., Awaitable]:
-    sig = inspect.signature(function)
-    extra_params = {
-        key: value for key, value in extra_params.items() if key in sig.parameters
-    }
-    is_coroutine = inspect.iscoroutinefunction(function)
-
-    async def new_function(**kwargs):
-        extra_kwargs = kwargs | extra_params
-        if is_coroutine:
-            return await function(**extra_kwargs)
-        return function(**extra_kwargs)
-
-    return new_function
-
-
-# Mutation on extra_params
-def get_tools(
-    tool_ids: list[str], user: UserModel, extra_params: dict
-) -> dict[str, dict]:
-    tools = {}
-    for tool_id in tool_ids:
-        toolkit = Tools.get_tool_by_id(tool_id)
-        if toolkit is None:
-            continue
-
-        module = webui_app.state.TOOLS.get(tool_id, None)
-        if module is None:
-            module, _ = load_toolkit_module_by_id(tool_id)
-            webui_app.state.TOOLS[tool_id] = module
-
-        extra_params["__id__"] = tool_id
-        if hasattr(module, "valves") and hasattr(module, "Valves"):
-            valves = Tools.get_tool_valves_by_id(tool_id) or {}
-            module.valves = module.Valves(**valves)
-
-        if hasattr(module, "UserValves"):
-            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
-                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-            )
-
-        for spec in toolkit.specs:
-            # TODO: Fix hack for OpenAI API
-            for val in spec.get("parameters", {}).get("properties", {}).values():
-                if val["type"] == "str":
-                    val["type"] = "string"
-            function_name = spec["name"]
-
-            # convert to function that takes only model params and inserts custom params
-            callable = apply_extra_params_to_tool_function(
-                getattr(module, function_name), extra_params
-            )
-
-            # TODO: This needs to be a pydantic model
-            tool_dict = {
-                "toolkit_id": tool_id,
-                "callable": callable,
-                "spec": spec,
-                "file_handler": hasattr(module, "file_handler") and module.file_handler,
-                "citation": hasattr(module, "citation") and module.citation,
-            }
-
-            # TODO: if collision, prepend toolkit name
-            if function_name in tools:
-                log.warning(f"Tool {function_name} already exists in another toolkit!")
-                log.warning(f"Collision between {toolkit} and {tool_id}.")
-                log.warning(f"Discarding {toolkit}.{function_name}")
-            else:
-                tools[function_name] = tool_dict
-    return tools
-
-
 async def get_content_from_response(response) -> Optional[str]:
     content = None
     if hasattr(response, "body_iterator"):
@@ -443,15 +370,17 @@ async def get_content_from_response(response) -> Optional[str]:
 async def chat_completion_tools_handler(
     body: dict, user: UserModel, extra_params: dict
 ) -> tuple[dict, dict]:
+    # If tool_ids field is present, call the functions
+    metadata = body.get("metadata", {})
+    tool_ids = metadata.get("tool_ids", None)
+    if not tool_ids:
+        return body, {}
+
     skip_files = False
     contexts = []
     citations = []
 
     task_model_id = get_task_model_id(body["model"])
-    # If tool_ids field is present, call the functions
-    tool_ids = body.pop("tool_ids", None)
-    if not tool_ids:
-        return body, {}
 
     log.debug(f"{tool_ids=}")
 
@@ -459,9 +388,9 @@ async def chat_completion_tools_handler(
         **extra_params,
         "__model__": app.state.MODELS[task_model_id],
         "__messages__": body["messages"],
-        "__files__": body.get("files", []),
+        "__files__": metadata.get("files", []),
     }
-    tools = get_tools(tool_ids, user, custom_params)
+    tools = get_tools(webui_app, tool_ids, user, custom_params)
     log.info(f"{tools=}")
 
     specs = [tool["spec"] for tool in tools.values()]
@@ -486,7 +415,7 @@ async def chat_completion_tools_handler(
         content = await get_content_from_response(response)
         log.debug(f"{content=}")
 
-        if content is None:
+        if not content:
             return body, {}
 
         result = json.loads(content)
@@ -521,13 +450,13 @@ async def chat_completion_tools_handler(
             contexts.append(tool_output)
 
     except Exception as e:
-        print(f"Error: {e}")
+        log.exception(f"Error: {e}")
         content = None
 
     log.debug(f"tool_contexts: {contexts}")
 
-    if skip_files and "files" in body:
-        del body["files"]
+    if skip_files and "files" in body.get("metadata", {}):
+        del body["metadata"]["files"]
 
     return body, {"contexts": contexts, "citations": citations}
 
@@ -536,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
     contexts = []
     citations = []
 
-    if files := body.pop("files", None):
+    if files := body.get("metadata", {}).get("files", None):
         contexts, citations = get_rag_context(
             files=files,
             messages=body["messages"],
@@ -597,6 +526,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             "message_id": body.pop("id", None),
             "session_id": body.pop("session_id", None),
             "valves": body.pop("valves", None),
+            "tool_ids": body.pop("tool_ids", None),
+            "files": body.pop("files", None),
         }
 
         __user__ = {
@@ -680,36 +611,29 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         ]
 
         response = await call_next(request)
-        if isinstance(response, StreamingResponse):
-            # If it's a streaming response, inject it as SSE event or NDJSON line
-            content_type = response.headers["Content-Type"]
-            if "text/event-stream" in content_type:
-                return StreamingResponse(
-                    self.openai_stream_wrapper(response.body_iterator, data_items),
-                )
-            if "application/x-ndjson" in content_type:
-                return StreamingResponse(
-                    self.ollama_stream_wrapper(response.body_iterator, data_items),
-                )
+        if not isinstance(response, StreamingResponse):
+            return response
 
-        return response
+        content_type = response.headers["Content-Type"]
+        is_openai = "text/event-stream" in content_type
+        is_ollama = "application/x-ndjson" in content_type
+        if not is_openai and not is_ollama:
+            return response
 
-    async def _receive(self, body: bytes):
-        return {"type": "http.request", "body": body, "more_body": False}
+        def wrap_item(item):
+            return f"data: {item}\n\n" if is_openai else f"{item}\n"
 
-    async def openai_stream_wrapper(self, original_generator, data_items):
-        for item in data_items:
-            yield f"data: {json.dumps(item)}\n\n"
+        async def stream_wrapper(original_generator, data_items):
+            for item in data_items:
+                yield wrap_item(json.dumps(item))
 
-        async for data in original_generator:
-            yield data
+            async for data in original_generator:
+                yield data
 
-    async def ollama_stream_wrapper(self, original_generator, data_items):
-        for item in data_items:
-            yield f"{json.dumps(item)}\n"
+        return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
 
-        async for data in original_generator:
-            yield data
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}
 
 
 app.add_middleware(ChatCompletionMiddleware)
@@ -1065,7 +989,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             detail="Model not found",
         )
     model = app.state.MODELS[model_id]
-
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":

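One behavioural detail worth calling out in this file: `tool_ids` and `files` are no longer read from the top level of the request body. The middleware pops them into `metadata`, and the tools, files, and filter handlers now read `body["metadata"]` instead. A small sketch with illustrative values:

    # Illustration of the relocated fields (values are made up). The middleware
    # pops tool_ids/files off the body into metadata; downstream handlers read
    # body["metadata"] rather than the top-level body.
    body = {
        "model": "llama3",
        "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
        "tool_ids": ["weather_toolkit"],
        "files": [{"type": "collection", "id": "docs"}],
    }
    body["metadata"] = {
        "tool_ids": body.pop("tool_ids", None),
        "files": body.pop("files", None),
    }
    assert "tool_ids" not in body and body["metadata"]["tool_ids"] == ["weather_toolkit"]
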
+ 104 - 0
backend/utils/schemas.py

@@ -0,0 +1,104 @@
+from pydantic import BaseModel, Field, create_model
+from typing import Any, Optional, Type
+
+
+def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
+    """
+    Converts a JSON schema to a Pydantic BaseModel class.
+
+    Args:
+        json_schema: The JSON schema to convert.
+
+    Returns:
+        A Pydantic BaseModel class.
+    """
+
+    # Extract the model name from the schema title.
+    model_name = tool_dict["name"]
+    schema = tool_dict["parameters"]
+
+    # Extract the field definitions from the schema properties.
+    field_definitions = {
+        name: json_schema_to_pydantic_field(name, prop, schema.get("required", []))
+        for name, prop in schema.get("properties", {}).items()
+    }
+
+    # Create the BaseModel class using create_model().
+    return create_model(model_name, **field_definitions)
+
+
+def json_schema_to_pydantic_field(
+    name: str, json_schema: dict[str, Any], required: list[str]
+) -> Any:
+    """
+    Converts a JSON schema property to a Pydantic field definition.
+
+    Args:
+        name: The field name.
+        json_schema: The JSON schema property.
+
+    Returns:
+        A Pydantic field definition.
+    """
+
+    # Get the field type.
+    type_ = json_schema_to_pydantic_type(json_schema)
+
+    # Get the field description.
+    description = json_schema.get("description")
+
+    # Get the field examples.
+    examples = json_schema.get("examples")
+
+    # Create a Field object with the type, description, and examples.
+    # The 'required' flag will be set later when creating the model.
+    return (
+        type_,
+        Field(
+            description=description,
+            examples=examples,
+            default=... if name in required else None,
+        ),
+    )
+
+
+def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
+    """
+    Converts a JSON schema type to a Pydantic type.
+
+    Args:
+        json_schema: The JSON schema to convert.
+
+    Returns:
+        A Pydantic type.
+    """
+
+    type_ = json_schema.get("type")
+
+    if type_ == "string" or type_ == "str":
+        return str
+    elif type_ == "integer" or type_ == "int":
+        return int
+    elif type_ == "number" or type_ == "float":
+        return float
+    elif type_ == "boolean" or type_ == "bool":
+        return bool
+    elif type_ == "array":
+        items_schema = json_schema.get("items")
+        if items_schema:
+            item_type = json_schema_to_pydantic_type(items_schema)
+            return list[item_type]
+        else:
+            return list
+    elif type_ == "object":
+        # Handle nested models.
+        properties = json_schema.get("properties")
+        if properties:
+            nested_model = json_schema_to_model(json_schema)
+            return nested_model
+        else:
+            return dict
+    elif type_ == "null":
+        return Optional[Any]  # Use Optional[Any] for nullable fields
+    else:
+        raise ValueError(f"Unsupported JSON schema type: {type_}")

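A minimal usage sketch for the new helper, assuming a toolkit spec shaped like the ones produced above (a `name` plus OpenAI-style `parameters`); the weather spec itself is made up:

    from utils.schemas import json_schema_to_model

    # Hypothetical toolkit spec; only the shape matters here.
    spec = {
        "name": "get_weather",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City name"},
                "days": {"type": "integer", "description": "Forecast length in days"},
            },
            "required": ["city"],
        },
    }

    WeatherArgs = json_schema_to_model(spec)
    args = WeatherArgs(city="Paris")   # "days" is optional and defaults to None
    print(args.model_dump())           # {'city': 'Paris', 'days': None}
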
+ 86 - 1
backend/utils/tools.py

@@ -1,5 +1,90 @@
 import inspect
-from typing import get_type_hints
+import logging
+from typing import Awaitable, Callable, get_type_hints
+
+from apps.webui.models.tools import Tools
+from apps.webui.models.users import UserModel
+from apps.webui.utils import load_toolkit_module_by_id
+
+from utils.schemas import json_schema_to_model
+
+log = logging.getLogger(__name__)
+
+
+def apply_extra_params_to_tool_function(
+    function: Callable, extra_params: dict
+) -> Callable[..., Awaitable]:
+    sig = inspect.signature(function)
+    extra_params = {
+        key: value for key, value in extra_params.items() if key in sig.parameters
+    }
+    is_coroutine = inspect.iscoroutinefunction(function)
+
+    async def new_function(**kwargs):
+        extra_kwargs = kwargs | extra_params
+        if is_coroutine:
+            return await function(**extra_kwargs)
+        return function(**extra_kwargs)
+
+    return new_function
+
+
+# Mutation on extra_params
+def get_tools(
+    webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
+) -> dict[str, dict]:
+    tools = {}
+    for tool_id in tool_ids:
+        toolkit = Tools.get_tool_by_id(tool_id)
+        if toolkit is None:
+            continue
+
+        module = webui_app.state.TOOLS.get(tool_id, None)
+        if module is None:
+            module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = module
+
+        extra_params["__id__"] = tool_id
+        if hasattr(module, "valves") and hasattr(module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id) or {}
+            module.valves = module.Valves(**valves)
+
+        if hasattr(module, "UserValves"):
+            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+        for spec in toolkit.specs:
+            # TODO: Fix hack for OpenAI API
+            for val in spec.get("parameters", {}).get("properties", {}).values():
+                if val["type"] == "str":
+                    val["type"] = "string"
+            function_name = spec["name"]
+
+            # convert to function that takes only model params and inserts custom params
+            original_func = getattr(module, function_name)
+            callable = apply_extra_params_to_tool_function(original_func, extra_params)
+            if hasattr(original_func, "__doc__"):
+                callable.__doc__ = original_func.__doc__
+
+            # TODO: This needs to be a pydantic model
+            tool_dict = {
+                "toolkit_id": tool_id,
+                "callable": callable,
+                "spec": spec,
+                "pydantic_model": json_schema_to_model(spec),
+                "file_handler": hasattr(module, "file_handler") and module.file_handler,
+                "citation": hasattr(module, "citation") and module.citation,
+            }
+
+            # TODO: if collision, prepend toolkit name
+            if function_name in tools:
+                log.warning(f"Tool {function_name} already exists in another toolkit!")
+                log.warning(f"Collision between {toolkit} and {tool_id}.")
+                log.warning(f"Discarding {toolkit}.{function_name}")
+            else:
+                tools[function_name] = tool_dict
+    return tools
 
 
 def doc_to_dict(docstring):
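
Finally, a quick sketch of what `apply_extra_params_to_tool_function` does for a single tool function: params named in the function's signature are injected from `extra_params`, the rest are dropped, and sync tools become awaitable. The tool function here is a made-up example:

    import asyncio

    from utils.tools import apply_extra_params_to_tool_function

    # Made-up tool function; __user__ is injected by the wrapper, city comes
    # from the model-facing call.
    def get_weather(city: str, __user__: dict) -> str:
        return f"Weather for {city}, requested by {__user__['name']}"

    wrapped = apply_extra_params_to_tool_function(
        get_weather,
        {"__user__": {"name": "alice"}, "__id__": "weather_toolkit"},  # __id__ not in signature, dropped
    )

    print(asyncio.run(wrapped(city="Paris")))
    # -> Weather for Paris, requested by alice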