Merge pull request #4602 from michaelpoluektov/tools-refac-1

refactor, perf: Tools refactor (progress PR 1)
Timothy Jaeryang Baek, 8 months ago
Parent commit: cbb0940ff8
3 changed files with 350 additions and 447 deletions
  1. backend/apps/webui/models/users.py (+3 −10)
  2. backend/main.py (+346 −436)
  3. backend/utils/task.py (+1 −1)

+ 3 - 10
backend/apps/webui/models/users.py

@@ -1,12 +1,10 @@
-from pydantic import BaseModel, ConfigDict, parse_obj_as
-from typing import Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import Optional
 import time
 
 from sqlalchemy import String, Column, BigInteger, Text
 
-from utils.misc import get_gravatar_url
-
-from apps.webui.internal.db import Base, JSONField, Session, get_db
+from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.chats import Chats
 
 ####################
@@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):
 
 
 class UsersTable:
-
     def insert_new_user(
         self,
         id: str,
@@ -122,7 +119,6 @@ class UsersTable:
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(api_key=api_key).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -131,7 +127,6 @@ class UsersTable:
     def get_user_by_email(self, email: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(email=email).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -140,7 +135,6 @@ class UsersTable:
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(oauth_sub=sub).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -195,7 +189,6 @@ class UsersTable:
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 db.query(User).filter_by(id=id).update(
                     {"last_active_at": int(time.time())}
                 )
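
Note: the dropped Session import reflects that every query in this file now goes through the get_db() context manager, and get_gravatar_url / parse_obj_as were simply unused. A minimal sketch of the access pattern the file standardizes on (names as they appear in the hunks above):

    with get_db() as db:
        user = db.query(User).filter_by(email=email).first()
        return UserModel.model_validate(user)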

+ 346 - 436
backend/main.py

@@ -51,13 +51,13 @@ from apps.webui.internal.db import Session
 
 
 from pydantic import BaseModel
-from typing import Optional
+from typing import Optional, Callable, Awaitable
 
 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
 from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users
+from apps.webui.models.users import Users, UserModel
 
 from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
 
@@ -72,7 +72,7 @@ from utils.utils import (
 from utils.task import (
     title_generation_template,
     search_query_generation_template,
-    tools_function_calling_generation_template,
+    tool_calling_generation_template,
 )
 from utils.misc import (
     get_last_user_message,
@@ -261,6 +261,7 @@ def get_filter_function_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0
 
@@ -282,164 +283,42 @@ def get_filter_function_ids(model):
     return filter_ids
 
 
-async def get_function_call_response(
-    messages,
-    files,
-    tool_id,
-    template,
-    task_model_id,
-    user,
-    __event_emitter__=None,
-    __event_call__=None,
-):
-    tool = Tools.get_tool_by_id(tool_id)
-    tools_specs = json.dumps(tool.specs, indent=2)
-    content = tools_function_calling_generation_template(template, tools_specs)
+async def get_content_from_response(response) -> Optional[str]:
+    content = None
+    if hasattr(response, "body_iterator"):
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+    else:
+        content = response["choices"][0]["message"]["content"]
+    return content
+
 
+def get_tool_call_payload(messages, task_model_id, content):
     user_message = get_last_user_message(messages)
-    prompt = (
-        "History:\n"
-        + "\n".join(
-            [
-                f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
-                for message in messages[::-1][:4]
-            ]
-        )
-        + f"\nQuery: {user_message}"
+    history = "\n".join(
+        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
+        for message in messages[::-1][:4]
     )
 
-    print(prompt)
+    prompt = f"History:\n{history}\nQuery: {user_message}"
 
-    payload = {
+    return {
         "model": task_model_id,
         "messages": [
             {"role": "system", "content": content},
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
-        "task": str(TASKS.FUNCTION_CALLING),
+        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
     }
 
-    try:
-        payload = filter_pipeline(payload, user)
-    except Exception as e:
-        raise e
-
-    model = app.state.MODELS[task_model_id]
-
-    response = None
-    try:
-        response = await generate_chat_completions(form_data=payload, user=user)
-        content = None
-
-        if hasattr(response, "body_iterator"):
-            async for chunk in response.body_iterator:
-                data = json.loads(chunk.decode("utf-8"))
-                content = data["choices"][0]["message"]["content"]
-
-            # Cleanup any remaining background tasks if necessary
-            if response.background is not None:
-                await response.background()
-        else:
-            content = response["choices"][0]["message"]["content"]
-
-        if content is None:
-            return None, None, False
-
-        # Parse the function response
-        print(f"content: {content}")
-        result = json.loads(content)
-        print(result)
-
-        citation = None
-
-        if "name" not in result:
-            return None, None, False
-
-        # Call the function
-        if tool_id in webui_app.state.TOOLS:
-            toolkit_module = webui_app.state.TOOLS[tool_id]
-        else:
-            toolkit_module, _ = load_toolkit_module_by_id(tool_id)
-            webui_app.state.TOOLS[tool_id] = toolkit_module
-
-        file_handler = False
-        # check if toolkit_module has file_handler self variable
-        if hasattr(toolkit_module, "file_handler"):
-            file_handler = True
-            print("file_handler: ", file_handler)
-
-        if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
-            valves = Tools.get_tool_valves_by_id(tool_id)
-            toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
-
-        function = getattr(toolkit_module, result["name"])
-        function_result = None
-        try:
-            # Get the signature of the function
-            sig = inspect.signature(function)
-            params = result["parameters"]
-
-            # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": tool_id,
-                "__messages__": messages,
-                "__files__": files,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                # Call the function with the '__user__' parameter included
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(toolkit_module, "UserValves"):
-                        __user__["valves"] = toolkit_module.UserValves(
-                            **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
-            if inspect.iscoroutinefunction(function):
-                function_result = await function(**params)
-            else:
-                function_result = function(**params)
-
-            if hasattr(toolkit_module, "citation") and toolkit_module.citation:
-                citation = {
-                    "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
-                    "document": [function_result],
-                    "metadata": [{"source": result["name"]}],
-                }
-        except Exception as e:
-            print(e)
-
-        # Add the function result to the system prompt
-        if function_result is not None:
-            return function_result, citation, file_handler
-    except Exception as e:
-        print(f"Error: {e}")
-
-    return None, None, False
 
-
-async def chat_completion_functions_handler(
-    body, model, user, __event_emitter__, __event_call__
-):
+async def chat_completion_inlets_handler(body, model, extra_params):
     skip_files = None
 
     filter_ids = get_filter_function_ids(model)
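
Note: this hunk splits the old monolithic get_function_call_response. get_content_from_response drains either a streaming response (anything with a body_iterator) or a plain completion dict; get_tool_call_payload rebuilds the prompt construction that was previously inlined. A sketch of the resulting payload, assuming a hypothetical task model id (the doubled "Query:" prefix is carried over verbatim from the removed code):

    payload = get_tool_call_payload(
        messages=[{"role": "user", "content": "What is the weather in Berlin?"}],
        task_model_id="llama3:latest",  # hypothetical model id
        content=tool_calling_generation_template(template, tools_specs),
    )
    # {"model": "llama3:latest",
    #  "messages": [
    #      {"role": "system", "content": content},
    #      {"role": "user",
    #       "content": 'Query: History:\nUSER: """What is the weather in Berlin?"""\n'
    #                  'Query: What is the weather in Berlin?'},
    #  ],
    #  "stream": False,
    #  "metadata": {"task": str(TASKS.FUNCTION_CALLING)}}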
@@ -475,37 +354,20 @@ async def chat_completion_functions_handler(
             params = {"body": body}
 
             # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": filter_id,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
+            custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
+            if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
                 try:
-                    if hasattr(function_module, "UserValves"):
-                        __user__["valves"] = function_module.UserValves(
-                            **Functions.get_user_valves_by_id_and_user_id(
-                                filter_id, user.id
-                            )
-                        )
+                    uid = custom_params["__user__"]["id"]
+                    custom_params["__user__"]["valves"] = function_module.UserValves(
+                        **Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
+                    )
                 except Exception as e:
                     print(e)
 
-                params = {**params, "__user__": __user__}
+            # Add extra params contained in the function signature
+            for key, value in custom_params.items():
+                if key in sig.parameters:
+                    params[key] = value
 
             if inspect.iscoroutinefunction(inlet):
                 body = await inlet(**params)
@@ -516,74 +378,171 @@ async def chat_completion_functions_handler(
             print(f"Error: {e}")
             raise e
 
-    if skip_files:
-        if "files" in body:
-            del body["files"]
+    if skip_files and "files" in body:
+        del body["files"]
 
     return body, {}
 
 
-async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
-    skip_files = None
+def get_tool_with_custom_params(
+    tool: Callable, custom_params: dict
+) -> Callable[..., Awaitable]:
+    sig = inspect.signature(tool)
+    extra_params = {
+        key: value for key, value in custom_params.items() if key in sig.parameters
+    }
+    is_coroutine = inspect.iscoroutinefunction(tool)
 
-    contexts = []
-    citations = None
+    async def new_tool(**kwargs):
+        extra_kwargs = kwargs | extra_params
+        if is_coroutine:
+            return await tool(**extra_kwargs)
+        return tool(**extra_kwargs)
+
+    return new_tool
+
+
+# Mutation on extra_params
+def get_configured_tools(
+    tool_ids: list[str], extra_params: dict, user: UserModel
+) -> dict[str, dict]:
+    tools = {}
+    for tool_id in tool_ids:
+        toolkit = Tools.get_tool_by_id(tool_id)
+        if toolkit is None:
+            continue
 
+        module = webui_app.state.TOOLS.get(tool_id, None)
+        if module is None:
+            module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = module
+
+        extra_params["__id__"] = tool_id
+        has_citation = hasattr(module, "citation") and module.citation
+        handles_files = hasattr(module, "file_handler") and module.file_handler
+        if hasattr(module, "valves") and hasattr(module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id) or {}
+            module.valves = module.Valves(**valves)
+
+        if hasattr(module, "UserValves"):
+            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+        for spec in toolkit.specs:
+            # TODO: Fix hack for OpenAI API
+            for val in spec.get("parameters", {}).get("properties", {}).values():
+                if val["type"] == "str":
+                    val["type"] = "string"
+            name = spec["name"]
+            callable = getattr(module, name)
+
+            # convert to function that takes only model params and inserts custom params
+            custom_callable = get_tool_with_custom_params(callable, extra_params)
+
+            # TODO: This needs to be a pydantic model
+            tool_dict = {
+                "spec": spec,
+                "citation": has_citation,
+                "file_handler": handles_files,
+                "toolkit_id": tool_id,
+                "callable": custom_callable,
+            }
+            # TODO: if collision, prepend toolkit name
+            if name in tools:
+                log.warning(f"Tool {name} already exists in another toolkit!")
+                log.warning(f"Collision between {toolkit} and {tool_id}.")
+                log.warning(f"Discarding {toolkit}.{name}")
+            else:
+                tools[name] = tool_dict
+
+    return tools
+
+
+async def chat_completion_tools_handler(
+    body: dict, user: UserModel, extra_params: dict
+) -> tuple[dict, dict]:
+    skip_files = False
+    contexts = []
+    citations = []
     task_model_id = get_task_model_id(body["model"])
 
     # If tool_ids field is present, call the functions
-    if "tool_ids" in body:
-        print(body["tool_ids"])
-        for tool_id in body["tool_ids"]:
-            print(tool_id)
-            try:
-                response, citation, file_handler = await get_function_call_response(
-                    messages=body["messages"],
-                    files=body.get("files", []),
-                    tool_id=tool_id,
-                    template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                    task_model_id=task_model_id,
-                    user=user,
-                    __event_emitter__=__event_emitter__,
-                    __event_call__=__event_call__,
-                )
+    tool_ids = body.pop("tool_ids", None)
+    if not tool_ids:
+        return body, {}
+
+    log.debug(f"{tool_ids=}")
+    custom_params = {
+        **extra_params,
+        "__model__": app.state.MODELS[task_model_id],
+        "__messages__": body["messages"],
+        "__files__": body.get("files", []),
+    }
+    configured_tools = get_configured_tools(tool_ids, custom_params, user)
 
-                print(file_handler)
-                if isinstance(response, str):
-                    contexts.append(response)
+    log.info(f"{configured_tools=}")
 
-                if citation:
-                    if citations is None:
-                        citations = [citation]
-                    else:
-                        citations.append(citation)
+    specs = [tool["spec"] for tool in configured_tools.values()]
+    tools_specs = json.dumps(specs)
+    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+    content = tool_calling_generation_template(template, tools_specs)
+    payload = get_tool_call_payload(body["messages"], task_model_id, content)
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        raise e
 
-                if file_handler:
-                    skip_files = True
+    try:
+        response = await generate_chat_completions(form_data=payload, user=user)
+        log.debug(f"{response=}")
+        content = await get_content_from_response(response)
+        log.debug(f"{content=}")
+        if content is None:
+            return body, {}
 
-            except Exception as e:
-                print(f"Error: {e}")
-        del body["tool_ids"]
-        print(f"tool_contexts: {contexts}")
+        result = json.loads(content)
+        tool_name = result.get("name", None)
+        if tool_name not in configured_tools:
+            return body, {}
 
-    if skip_files:
-        if "files" in body:
-            del body["files"]
+        tool_params = result.get("parameters", {})
+        toolkit_id = configured_tools[tool_name]["toolkit_id"]
+        try:
+            tool_output = await configured_tools[tool_name]["callable"](**tool_params)
+        except Exception as e:
+            tool_output = str(e)
+        if configured_tools[tool_name]["citation"]:
+            citations.append(
+                {
+                    "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
+                    "document": [tool_output],
+                    "metadata": [{"source": tool_name}],
+                }
+            )
+        if configured_tools[tool_name]["file_handler"]:
+            skip_files = True
 
-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
+        if isinstance(tool_output, str):
+            contexts.append(tool_output)
 
+    except Exception as e:
+        print(f"Error: {e}")
+        content = None
 
-async def chat_completion_files_handler(body):
-    contexts = []
-    citations = None
+    log.debug(f"tool_contexts: {contexts}")
 
-    if "files" in body:
-        files = body["files"]
+    if skip_files and "files" in body:
         del body["files"]
 
+    return body, {"contexts": contexts, "citations": citations}
+
+
+async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
+    contexts = []
+    citations = []
+
+    if files := body.pop("files", None):
         contexts, citations = get_rag_context(
             files=files,
             messages=body["messages"],
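
Note: get_tool_with_custom_params closes over the subset of custom_params that actually appears in the tool's signature, so framework-injected arguments are merged with the model-supplied ones at call time and the wrapper is uniformly awaitable. A hedged sketch with a hypothetical tool:

    def get_weather(city: str, __user__: dict) -> str:  # hypothetical tool function
        return f"{city}: sunny (requested by {__user__['name']})"

    wrapped = get_tool_with_custom_params(
        get_weather,
        {"__user__": {"name": "Ada"}, "__event_emitter__": emitter},  # emitter assumed in scope
    )
    # __event_emitter__ is filtered out (not in the signature); __user__ is injected.
    result = await wrapped(city="Berlin")  # "Berlin: sunny (requested by Ada)"

Note also that get_configured_tools flattens every toolkit into a single dict keyed by function name; on a cross-toolkit collision the later definition is discarded with a warning, per the TODO in the hunk.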
@@ -596,134 +555,130 @@ async def chat_completion_files_handler(body):
 
         log.debug(f"rag_contexts: {contexts}, citations: {citations}")
 
-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
-
+    return body, {"contexts": contexts, "citations": citations}
 
-class ChatCompletionMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
 
-            try:
-                body, model, user = await get_body_and_model_and_user(request)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
-
-            metadata = {
-                "chat_id": body.pop("chat_id", None),
-                "message_id": body.pop("id", None),
-                "session_id": body.pop("session_id", None),
-                "valves": body.pop("valves", None),
-            }
-
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
+def is_chat_completion_request(request):
+    return request.method == "POST" and any(
+        endpoint in request.url.path
+        for endpoint in ["/ollama/api/chat", "/chat/completions"]
+    )
 
-            # Initialize data_items to store additional data to be sent to the client
-            data_items = []
 
-            # Initialize context, and citations
-            contexts = []
-            citations = []
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if not is_chat_completion_request(request):
+            return await call_next(request)
+        log.debug(f"request.url.path: {request.url.path}")
 
-            try:
-                body, flags = await chat_completion_functions_handler(
-                    body, model, user, __event_emitter__, __event_call__
-                )
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
+        try:
+            body, model, user = await get_body_and_model_and_user(request)
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
 
-            try:
-                body, flags = await chat_completion_tools_handler(
-                    body, user, __event_emitter__, __event_call__
-                )
+        metadata = {
+            "chat_id": body.pop("chat_id", None),
+            "message_id": body.pop("id", None),
+            "session_id": body.pop("session_id", None),
+            "valves": body.pop("valves", None),
+        }
 
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
 
-            try:
-                body, flags = await chat_completion_files_handler(body)
+        extra_params = {
+            "__user__": __user__,
+            "__event_emitter__": get_event_emitter(metadata),
+            "__event_call__": get_event_call(metadata),
+        }
 
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        # Initialize data_items to store additional data to be sent to the client
+        # Initialize contexts and citations
+        data_items = []
+        contexts = []
+        citations = []
 
-            # If context is not empty, insert it into the messages
-            if len(contexts) > 0:
-                context_string = "/n".join(contexts).strip()
-                prompt = get_last_user_message(body["messages"])
-
-                # Workaround for Ollama 2.0+ system prompt issue
-                # TODO: replace with add_or_update_system_message
-                if model["owned_by"] == "ollama":
-                    body["messages"] = prepend_to_first_user_message_content(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-                else:
-                    body["messages"] = add_or_update_system_message(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
+        try:
+            body, flags = await chat_completion_inlets_handler(
+                body, model, extra_params
+            )
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
 
-            # If there are citations, add them to the data_items
-            if len(citations) > 0:
-                data_items.append({"citations": citations})
-
-            body["metadata"] = metadata
-            modified_body_bytes = json.dumps(body).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
-
-            response = await call_next(request)
-            if isinstance(response, StreamingResponse):
-                # If it's a streaming response, inject it as SSE event or NDJSON line
-                content_type = response.headers.get("Content-Type")
-                if "text/event-stream" in content_type:
-                    return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, data_items),
-                    )
-                if "application/x-ndjson" in content_type:
-                    return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, data_items),
-                    )
+        try:
+            body, flags = await chat_completion_tools_handler(body, user, extra_params)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
 
-                return response
+        try:
+            body, flags = await chat_completion_files_handler(body)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
+
+        # If context is not empty, insert it into the messages
+        if len(contexts) > 0:
+            context_string = "\n".join(contexts).strip()
+            prompt = get_last_user_message(body["messages"])
+            if prompt is None:
+                raise Exception("No user message found")
+            # Workaround for Ollama 2.0+ system prompt issue
+            # TODO: replace with add_or_update_system_message
+            if model["owned_by"] == "ollama":
+                body["messages"] = prepend_to_first_user_message_content(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
             else:
-                return response
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
+
+        # If there are citations, add them to the data_items
+        if len(citations) > 0:
+            data_items.append({"citations": citations})
+
+        body["metadata"] = metadata
+        modified_body_bytes = json.dumps(body).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]
 
-        # If it's not a chat completion request, just pass it through
         response = await call_next(request)
+        if isinstance(response, StreamingResponse):
+            # If it's a streaming response, inject it as SSE event or NDJSON line
+            content_type = response.headers["Content-Type"]
+            if "text/event-stream" in content_type:
+                return StreamingResponse(
+                    self.openai_stream_wrapper(response.body_iterator, data_items),
+                )
+            if "application/x-ndjson" in content_type:
+                return StreamingResponse(
+                    self.ollama_stream_wrapper(response.body_iterator, data_items),
+                )
+
         return response
 
     async def _receive(self, body: bytes):
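
Note: ChatCompletionMiddleware (and PipelineMiddleware below) rewrites the cached request body in place and re-derives content-length. This reaches into Starlette internals (request._body and headers.__dict__["_list"]), so it is sensitive to Starlette version changes; the pattern in isolation:

    modified_body_bytes = json.dumps(body).encode("utf-8")
    # Starlette serves this cached value on subsequent await request.body() calls
    request._body = modified_body_bytes
    # Rebuild the raw header list so content-length matches the new body
    request.headers.__dict__["_list"] = [
        (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
        *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
    ]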
@@ -790,19 +745,21 @@ def filter_pipeline(payload, user):
             url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
             key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
 
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/inlet",
-                    headers=headers,
-                    json={
-                        "user": user,
-                        "body": payload,
-                    },
-                )
+            if key == "":
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            r = requests.post(
+                f"{url}/{filter['id']}/filter/inlet",
+                headers=headers,
+                json={
+                    "user": user,
+                    "body": payload,
+                },
+            )
 
-                r.raise_for_status()
-                payload = r.json()
+            r.raise_for_status()
+            payload = r.json()
         except Exception as e:
             # Handle connection error here
             print(f"Connection error: {e}")
@@ -817,44 +774,39 @@ def filter_pipeline(payload, user):
 
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
-
-            # Read the original request body
-            body = await request.body()
-            # Decode body to string
-            body_str = body.decode("utf-8")
-            # Parse string to JSON
-            data = json.loads(body_str) if body_str else {}
-
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
+        if not is_chat_completion_request(request):
+            return await call_next(request)
 
-            try:
-                data = filter_pipeline(data, user)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=e.args[0],
-                    content={"detail": e.args[1]},
-                )
+        log.debug(f"request.url.path: {request.url.path}")
 
-            modified_body_bytes = json.dumps(data).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
+        # Read the original request body
+        body = await request.body()
+        # Decode body to string
+        body_str = body.decode("utf-8")
+        # Parse string to JSON
+        data = json.loads(body_str) if body_str else {}
+
+        user = get_current_user(
+            request,
+            get_http_authorization_cred(request.headers["Authorization"]),
+        )
+
+        try:
+            data = filter_pipeline(data, user)
+        except Exception as e:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
+            )
+
+        modified_body_bytes = json.dumps(data).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]
 
         response = await call_next(request)
         return response
@@ -1019,6 +971,8 @@ async def get_all_models():
         model["actions"] = []
         for action_id in action_ids:
             action = Functions.get_function_by_id(action_id)
+            if action is None:
+                raise Exception(f"Action not found: {action_id}")
 
             if action_id in webui_app.state.FUNCTIONS:
                 function_module = webui_app.state.FUNCTIONS[action_id]
@@ -1099,22 +1053,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         )
     model = app.state.MODELS[model_id]
 
-    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
-    task = None
-    if "task" in form_data:
-        task = form_data["task"]
-        del form_data["task"]
-
-    if task:
-        if "metadata" in form_data:
-            form_data["metadata"]["task"] = task
-        else:
-            form_data["metadata"] = {"task": task}
-
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
-        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
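
Note: with the task-relocation logic removed from generate_chat_completions, each internal task caller now sets the label itself, as the hunks below show for title, query, and emoji generation. The shared payload shape, with a hypothetical prompt:

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": prompt}],  # prompt built per task
        "stream": False,
        "metadata": {"task": str(TASKS.TITLE_GENERATION)},  # previously a top-level "task" key
    }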
@@ -1198,6 +1139,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel to include valves
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0
 
@@ -1487,7 +1429,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.TITLE_GENERATION),
+        "metadata": {"task": str(TASKS.TITLE_GENERATION)},
     }
 
     log.debug(payload)
@@ -1540,7 +1482,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": str(TASKS.QUERY_GENERATION),
+        "metadata": {"task": str(TASKS.QUERY_GENERATION)},
     }
 
     print(payload)
@@ -1597,7 +1539,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.EMOJI_GENERATION),
+        "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
     }
 
     log.debug(payload)
@@ -1616,41 +1558,6 @@ Message: """{{prompt}}"""
     return await generate_chat_completions(form_data=payload, user=user)
 
 
-@app.post("/api/task/tools/completions")
-async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
-    print("get_tools_function_calling")
-
-    model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Model not found",
-        )
-
-    # Check if the user has a custom task model
-    # If the user has a custom task model, use that model
-    model_id = get_task_model_id(model_id)
-
-    print(model_id)
-    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
-
-    try:
-        context, _, _ = await get_function_call_response(
-            form_data["messages"],
-            form_data.get("files", []),
-            form_data["tool_id"],
-            template,
-            model_id,
-            user,
-        )
-        return context
-    except Exception as e:
-        return JSONResponse(
-            status_code=e.args[0],
-            content={"detail": e.args[1]},
-        )
-
-
 ##################################
 #
 # Pipelines Endpoints
@@ -1689,7 +1596,7 @@ async def upload_pipeline(
 ):
     print("upload_pipeline", urlIdx, file.filename)
     # Check if the uploaded file is a python file
-    if not file.filename.endswith(".py"):
+    if not (file.filename and file.filename.endswith(".py")):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Only Python (.py) files are allowed.",
@@ -2138,7 +2045,10 @@ async def oauth_login(provider: str, request: Request):
     redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
         "oauth_callback", provider=provider
     )
-    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+    client = oauth.create_client(provider)
+    if client is None:
+        raise HTTPException(404)
+    return await client.authorize_redirect(request, redirect_uri)
 
 
 # OAuth login logic is as follows:

+ 1 - 1
backend/utils/task.py

@@ -121,6 +121,6 @@ def search_query_generation_template(
     return template
 
 
-def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
+def tool_calling_generation_template(template: str, tools_specs: str) -> str:
     template = template.replace("{{TOOLS}}", tools_specs)
     return template
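
Note: the rename tracks the helper's narrow job; it only substitutes the JSON-serialized specs into the configured prompt template. Usage as wired up in main.py above, with a hypothetical template string:

    template = "You may use these tools: {{TOOLS}}"  # hypothetical; the real value comes from app config
    specs = [tool["spec"] for tool in configured_tools.values()]
    content = tool_calling_generation_template(template, json.dumps(specs))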