Merge pull request #4724 from michaelpoluektov/tools-refac-2.1

feat: Add `__tools__` optional param for function pipes
Timothy Jaeryang Baek 8 months ago
parent commit ee526b4b07
6 changed files with 249 additions and 133 deletions
  1. backend/apps/ollama/main.py (+2 -5)
  2. backend/apps/webui/main.py (+19 -11)
  3. backend/config.py (+0 -1)
  4. backend/main.py (+38 -115)
  5. backend/utils/schemas.py (+104 -0)
  6. backend/utils/tools.py (+86 -1)

+ 2 - 5
backend/apps/ollama/main.py

@@ -731,11 +731,8 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
-
-    payload = {
-        **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
-    }
+    payload = {**form_data.model_dump(exclude_none=True)}
+    log.debug(f"{payload = }")
     if "metadata" in payload:
         del payload["metadata"]
 

+ 19 - 11
backend/apps/webui/main.py

@@ -26,6 +26,7 @@ from utils.misc import (
     apply_model_system_prompt_to_body,
 )
 
+from utils.tools import get_tools
 
 from config import (
     SHOW_ADMIN_DETAILS,
@@ -275,7 +276,9 @@ def get_function_params(function_module, form_data, user, extra_params={}):
 async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
-    metadata = form_data.pop("metadata", None)
+    metadata = form_data.pop("metadata", {})
+    files = metadata.get("files", [])
+    tool_ids = metadata.get("tool_ids", [])
 
     __event_emitter__ = None
     __event_call__ = None
@@ -287,6 +290,20 @@ async def generate_function_chat_completion(form_data, user):
             __event_call__ = get_event_call(metadata)
         __task__ = metadata.get("task", None)
 
+    extra_params = {
+        "__event_emitter__": __event_emitter__,
+        "__event_call__": __event_call__,
+        "__task__": __task__,
+    }
+    tools_params = {
+        **extra_params,
+        "__model__": app.state.MODELS[form_data["model"]],
+        "__messages__": form_data["messages"],
+        "__files__": files,
+    }
+    configured_tools = get_tools(app, tool_ids, user, tools_params)
+
+    extra_params["__tools__"] = configured_tools
     if model_info:
         if model_info.base_model_id:
             form_data["model"] = model_info.base_model_id
@@ -299,16 +316,7 @@ async def generate_function_chat_completion(form_data, user):
     function_module = get_function_module(pipe_id)
 
     pipe = function_module.pipe
-    params = get_function_params(
-        function_module,
-        form_data,
-        user,
-        {
-            "__event_emitter__": __event_emitter__,
-            "__event_call__": __event_call__,
-            "__task__": __task__,
-        },
-    )
+    params = get_function_params(function_module, form_data, user, extra_params)
 
     if form_data["stream"]:

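Note on usage: below is a minimal, hypothetical sketch of a function pipe that opts into the new `__tools__` parameter. The class and return value are illustrative only; what the diff guarantees is that `get_tools(...)` builds the dict and `extra_params["__tools__"]` forwards it to `get_function_params`, presumably only when the pipe's signature declares it (as with the other `__`-prefixed params).

```python
# Hypothetical pipe, not part of this commit.
class Pipe:
    async def pipe(self, body: dict, __tools__: dict = {}) -> str:
        # Each value carries the keys assembled in backend/utils/tools.py:
        # "toolkit_id", "callable", "spec", "pydantic_model", "file_handler", "citation".
        names = ", ".join(__tools__) if __tools__ else "none"
        return f"Tools configured for this chat: {names}"
```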
+ 0 - 1
backend/config.py

@@ -176,7 +176,6 @@ for version in soup.find_all("h2"):
 
 CHANGELOG = changelog_json
 
-
 ####################################
 # SAFE_MODE
 ####################################

+ 38 - 115
backend/main.py

@@ -14,6 +14,7 @@ import requests
 import mimetypes
 import shutil
 import inspect
+from typing import Optional
 
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi.staticfiles import StaticFiles
@@ -51,15 +52,13 @@ from apps.webui.internal.db import Session
 
 
 from pydantic import BaseModel
-from typing import Optional, Callable, Awaitable
 
 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
-from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
 from apps.webui.models.users import Users, UserModel
 
-from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
+from apps.webui.utils import load_function_module_by_id
 
 from utils.utils import (
     get_admin_user,
@@ -76,6 +75,8 @@ from utils.task import (
     tools_function_calling_generation_template,
     moa_response_generation_template,
 )
+
+from utils.tools import get_tools
 from utils.misc import (
 from utils.misc import (
     get_last_user_message,
     get_last_user_message,
     add_or_update_system_message,
     add_or_update_system_message,
@@ -325,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
             print(f"Error: {e}")
             raise e
 
-    if skip_files and "files" in body:
-        del body["files"]
+    if skip_files and "files" in body.get("metadata", {}):
+        del body["metadata"]["files"]
 
     return body, {}
 
@@ -351,80 +352,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content):
     }
 
 
-def apply_extra_params_to_tool_function(
-    function: Callable, extra_params: dict
-) -> Callable[..., Awaitable]:
-    sig = inspect.signature(function)
-    extra_params = {
-        key: value for key, value in extra_params.items() if key in sig.parameters
-    }
-    is_coroutine = inspect.iscoroutinefunction(function)
-
-    async def new_function(**kwargs):
-        extra_kwargs = kwargs | extra_params
-        if is_coroutine:
-            return await function(**extra_kwargs)
-        return function(**extra_kwargs)
-
-    return new_function
-
-
-# Mutation on extra_params
-def get_tools(
-    tool_ids: list[str], user: UserModel, extra_params: dict
-) -> dict[str, dict]:
-    tools = {}
-    for tool_id in tool_ids:
-        toolkit = Tools.get_tool_by_id(tool_id)
-        if toolkit is None:
-            continue
-
-        module = webui_app.state.TOOLS.get(tool_id, None)
-        if module is None:
-            module, _ = load_toolkit_module_by_id(tool_id)
-            webui_app.state.TOOLS[tool_id] = module
-
-        extra_params["__id__"] = tool_id
-        if hasattr(module, "valves") and hasattr(module, "Valves"):
-            valves = Tools.get_tool_valves_by_id(tool_id) or {}
-            module.valves = module.Valves(**valves)
-
-        if hasattr(module, "UserValves"):
-            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
-                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-            )
-
-        for spec in toolkit.specs:
-            # TODO: Fix hack for OpenAI API
-            for val in spec.get("parameters", {}).get("properties", {}).values():
-                if val["type"] == "str":
-                    val["type"] = "string"
-            function_name = spec["name"]
-
-            # convert to function that takes only model params and inserts custom params
-            callable = apply_extra_params_to_tool_function(
-                getattr(module, function_name), extra_params
-            )
-
-            # TODO: This needs to be a pydantic model
-            tool_dict = {
-                "toolkit_id": tool_id,
-                "callable": callable,
-                "spec": spec,
-                "file_handler": hasattr(module, "file_handler") and module.file_handler,
-                "citation": hasattr(module, "citation") and module.citation,
-            }
-
-            # TODO: if collision, prepend toolkit name
-            if function_name in tools:
-                log.warning(f"Tool {function_name} already exists in another toolkit!")
-                log.warning(f"Collision between {toolkit} and {tool_id}.")
-                log.warning(f"Discarding {toolkit}.{function_name}")
-            else:
-                tools[function_name] = tool_dict
-    return tools
-
-
 async def get_content_from_response(response) -> Optional[str]:
     content = None
     if hasattr(response, "body_iterator"):
@@ -443,15 +370,17 @@ async def get_content_from_response(response) -> Optional[str]:
 async def chat_completion_tools_handler(
     body: dict, user: UserModel, extra_params: dict
 ) -> tuple[dict, dict]:
+    # If tool_ids field is present, call the functions
+    metadata = body.get("metadata", {})
+    tool_ids = metadata.get("tool_ids", None)
+    if not tool_ids:
+        return body, {}
+
     skip_files = False
     contexts = []
     citations = []
 
     task_model_id = get_task_model_id(body["model"])
-    # If tool_ids field is present, call the functions
-    tool_ids = body.pop("tool_ids", None)
-    if not tool_ids:
-        return body, {}
 
     log.debug(f"{tool_ids=}")
 
@@ -459,9 +388,9 @@ async def chat_completion_tools_handler(
         **extra_params,
         "__model__": app.state.MODELS[task_model_id],
         "__messages__": body["messages"],
-        "__files__": body.get("files", []),
+        "__files__": metadata.get("files", []),
     }
-    tools = get_tools(tool_ids, user, custom_params)
+    tools = get_tools(webui_app, tool_ids, user, custom_params)
     log.info(f"{tools=}")
 
     specs = [tool["spec"] for tool in tools.values()]
@@ -486,7 +415,7 @@ async def chat_completion_tools_handler(
         content = await get_content_from_response(response)
         log.debug(f"{content=}")
 
-        if content is None:
+        if not content:
             return body, {}
 
         result = json.loads(content)
@@ -521,13 +450,13 @@ async def chat_completion_tools_handler(
             contexts.append(tool_output)
 
     except Exception as e:
-        print(f"Error: {e}")
+        log.exception(f"Error: {e}")
         content = None
 
     log.debug(f"tool_contexts: {contexts}")
 
-    if skip_files and "files" in body:
-        del body["files"]
+    if skip_files and "files" in body.get("metadata", {}):
+        del body["metadata"]["files"]
 
     return body, {"contexts": contexts, "citations": citations}
 
@@ -536,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
     contexts = []
     citations = []
 
-    if files := body.pop("files", None):
+    if files := body.get("metadata", {}).get("files", None):
         contexts, citations = get_rag_context(
             files=files,
             messages=body["messages"],
@@ -597,6 +526,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             "message_id": body.pop("id", None),
             "session_id": body.pop("session_id", None),
             "valves": body.pop("valves", None),
+            "tool_ids": body.pop("tool_ids", None),
+            "files": body.pop("files", None),
         }
 
         __user__ = {
@@ -680,36 +611,29 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         ]
 
         response = await call_next(request)
-        if isinstance(response, StreamingResponse):
-            # If it's a streaming response, inject it as SSE event or NDJSON line
-            content_type = response.headers["Content-Type"]
-            if "text/event-stream" in content_type:
-                return StreamingResponse(
-                    self.openai_stream_wrapper(response.body_iterator, data_items),
-                )
-            if "application/x-ndjson" in content_type:
-                return StreamingResponse(
-                    self.ollama_stream_wrapper(response.body_iterator, data_items),
-                )
+        if not isinstance(response, StreamingResponse):
+            return response
 
-        return response
+        content_type = response.headers["Content-Type"]
+        is_openai = "text/event-stream" in content_type
+        is_ollama = "application/x-ndjson" in content_type
+        if not is_openai and not is_ollama:
+            return response
 
-    async def _receive(self, body: bytes):
-        return {"type": "http.request", "body": body, "more_body": False}
+        def wrap_item(item):
+            return f"data: {item}\n\n" if is_openai else f"{item}\n"
 
-    async def openai_stream_wrapper(self, original_generator, data_items):
-        for item in data_items:
-            yield f"data: {json.dumps(item)}\n\n"
+        async def stream_wrapper(original_generator, data_items):
+            for item in data_items:
+                yield wrap_item(json.dumps(item))
 
-        async for data in original_generator:
-            yield data
+            async for data in original_generator:
+                yield data
 
-    async def ollama_stream_wrapper(self, original_generator, data_items):
-        for item in data_items:
-            yield f"{json.dumps(item)}\n"
+        return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
 
-        async for data in original_generator:
-            yield data
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}
 
 
 app.add_middleware(ChatCompletionMiddleware)
@@ -1065,7 +989,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             detail="Model not found",
         )
     model = app.state.MODELS[model_id]
-
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":

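For reference, a self-contained sketch of what the unified `stream_wrapper` in this hunk does: injected `data_items` are framed first (SSE framing for OpenAI-style streams, NDJSON for Ollama), then the upstream body iterator is passed through untouched. Everything below is illustrative; only the two framing strings mirror the diff.

```python
# Illustrative demo of the wrapping behaviour, runnable on its own.
import asyncio, json

async def fake_model_stream():
    yield 'data: {"choices": [{"delta": {"content": "Hello"}}]}\n\n'

async def main(is_openai: bool = True):
    data_items = [{"citations": ["example citation"]}]  # events produced by the middleware

    def wrap_item(item):
        return f"data: {item}\n\n" if is_openai else f"{item}\n"

    async def stream_wrapper(original_generator, items):
        for item in items:                      # injected events are flushed first
            yield wrap_item(json.dumps(item))
        async for data in original_generator:   # then the model stream, unchanged
            yield data

    async for chunk in stream_wrapper(fake_model_stream(), data_items):
        print(chunk, end="")

asyncio.run(main())
```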
+ 104 - 0
backend/utils/schemas.py

@@ -0,0 +1,104 @@
+from pydantic import BaseModel, Field, create_model
+from typing import Any, Optional, Type
+
+
+def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
+    """
+    Converts a JSON schema to a Pydantic BaseModel class.
+
+    Args:
+        json_schema: The JSON schema to convert.
+
+    Returns:
+        A Pydantic BaseModel class.
+    """
+
+    # Extract the model name from the schema title.
+    model_name = tool_dict["name"]
+    schema = tool_dict["parameters"]
+
+    # Extract the field definitions from the schema properties.
+    field_definitions = {
+        name: json_schema_to_pydantic_field(name, prop, schema.get("required", []))
+        for name, prop in schema.get("properties", {}).items()
+    }
+
+    # Create the BaseModel class using create_model().
+    return create_model(model_name, **field_definitions)
+
+
+def json_schema_to_pydantic_field(
+    name: str, json_schema: dict[str, Any], required: list[str]
+) -> Any:
+    """
+    Converts a JSON schema property to a Pydantic field definition.
+
+    Args:
+        name: The field name.
+        json_schema: The JSON schema property.
+
+    Returns:
+        A Pydantic field definition.
+    """
+
+    # Get the field type.
+    type_ = json_schema_to_pydantic_type(json_schema)
+
+    # Get the field description.
+    description = json_schema.get("description")
+
+    # Get the field examples.
+    examples = json_schema.get("examples")
+
+    # Create a Field object with the type, description, and examples.
+    # The 'required' flag will be set later when creating the model.
+    return (
+        type_,
+        Field(
+            description=description,
+            examples=examples,
+            default=... if name in required else None,
+        ),
+    )
+
+
+def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
+    """
+    Converts a JSON schema type to a Pydantic type.
+
+    Args:
+        json_schema: The JSON schema to convert.
+
+    Returns:
+        A Pydantic type.
+    """
+
+    type_ = json_schema.get("type")
+
+    if type_ == "string" or type_ == "str":
+        return str
+    elif type_ == "integer" or type_ == "int":
+        return int
+    elif type_ == "number" or type_ == "float":
+        return float
+    elif type_ == "boolean" or type_ == "bool":
+        return bool
+    elif type_ == "array":
+        items_schema = json_schema.get("items")
+        if items_schema:
+            item_type = json_schema_to_pydantic_type(items_schema)
+            return list[item_type]
+        else:
+            return list
+    elif type_ == "object":
+        # Handle nested models.
+        properties = json_schema.get("properties")
+        if properties:
+            nested_model = json_schema_to_model(json_schema)
+            return nested_model
+        else:
+            return dict
+    elif type_ == "null":
+        return Optional[Any]  # Use Optional[Any] for nullable fields
+    else:
+        raise ValueError(f"Unsupported JSON schema type: {type_}")

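A quick sketch of the new helper in use; the `get_weather` spec below is made up and only mirrors the `{"name": ..., "parameters": ...}` shape `json_schema_to_model` expects (the same shape as a toolkit `spec`), assuming the backend's Pydantic v2.

```python
# Hypothetical spec, for illustration only.
from utils.schemas import json_schema_to_model

spec = {
    "name": "get_weather",
    "parameters": {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "City to look up"},
            "days": {"type": "integer", "description": "Forecast horizon in days"},
        },
        "required": ["city"],
    },
}

WeatherArgs = json_schema_to_model(spec)
print(WeatherArgs(city="Berlin"))        # required field set, optional one defaults to None
print(WeatherArgs.model_json_schema())   # JSON schema of the generated Pydantic model
```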
+ 86 - 1
backend/utils/tools.py

@@ -1,5 +1,90 @@
 import inspect
-from typing import get_type_hints
+import logging
+from typing import Awaitable, Callable, get_type_hints
+
+from apps.webui.models.tools import Tools
+from apps.webui.models.users import UserModel
+from apps.webui.utils import load_toolkit_module_by_id
+
+from utils.schemas import json_schema_to_model
+
+log = logging.getLogger(__name__)
+
+
+def apply_extra_params_to_tool_function(
+    function: Callable, extra_params: dict
+) -> Callable[..., Awaitable]:
+    sig = inspect.signature(function)
+    extra_params = {
+        key: value for key, value in extra_params.items() if key in sig.parameters
+    }
+    is_coroutine = inspect.iscoroutinefunction(function)
+
+    async def new_function(**kwargs):
+        extra_kwargs = kwargs | extra_params
+        if is_coroutine:
+            return await function(**extra_kwargs)
+        return function(**extra_kwargs)
+
+    return new_function
+
+
+# Mutation on extra_params
+def get_tools(
+    webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
+) -> dict[str, dict]:
+    tools = {}
+    for tool_id in tool_ids:
+        toolkit = Tools.get_tool_by_id(tool_id)
+        if toolkit is None:
+            continue
+
+        module = webui_app.state.TOOLS.get(tool_id, None)
+        if module is None:
+            module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = module
+
+        extra_params["__id__"] = tool_id
+        if hasattr(module, "valves") and hasattr(module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id) or {}
+            module.valves = module.Valves(**valves)
+
+        if hasattr(module, "UserValves"):
+            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+        for spec in toolkit.specs:
+            # TODO: Fix hack for OpenAI API
+            for val in spec.get("parameters", {}).get("properties", {}).values():
+                if val["type"] == "str":
+                    val["type"] = "string"
+            function_name = spec["name"]
+
+            # convert to function that takes only model params and inserts custom params
+            original_func = getattr(module, function_name)
+            callable = apply_extra_params_to_tool_function(original_func, extra_params)
+            if hasattr(original_func, "__doc__"):
+                callable.__doc__ = original_func.__doc__
+
+            # TODO: This needs to be a pydantic model
+            tool_dict = {
+                "toolkit_id": tool_id,
+                "callable": callable,
+                "spec": spec,
+                "pydantic_model": json_schema_to_model(spec),
+                "file_handler": hasattr(module, "file_handler") and module.file_handler,
+                "citation": hasattr(module, "citation") and module.citation,
+            }
+
+            # TODO: if collision, prepend toolkit name
+            if function_name in tools:
+                log.warning(f"Tool {function_name} already exists in another toolkit!")
+                log.warning(f"Collision between {toolkit} and {tool_id}.")
+                log.warning(f"Discarding {toolkit}.{function_name}")
+            else:
+                tools[function_name] = tool_dict
+    return tools
 
 
 def doc_to_dict(docstring):
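To illustrate the relocated helper, a small sketch of `apply_extra_params_to_tool_function` in isolation: the wrapper filters `extra_params` down to the names the tool actually declares and always returns an awaitable. The `search` tool below is hypothetical.

```python
import asyncio
from utils.tools import apply_extra_params_to_tool_function

def search(query: str, __user__: dict = {}) -> str:
    # A plain sync tool; the wrapper still exposes it as a coroutine.
    return f"{__user__.get('name')} searched for {query!r}"

# __event_emitter__ is silently dropped because search() does not declare it.
wrapped = apply_extra_params_to_tool_function(
    search, {"__user__": {"name": "alice"}, "__event_emitter__": None}
)
print(asyncio.run(wrapped(query="open-webui")))  # -> alice searched for 'open-webui'
```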