Browse Source

Merge pull request #4602 from michaelpoluektov/tools-refac-1

refactor, perf: Tools refactor (progress PR 1)
Timothy Jaeryang Baek 8 months ago
parent
commit
cbb0940ff8
3 changed files with 350 additions and 447 deletions
  1. +3 -10   backend/apps/webui/models/users.py
  2. +346 -436 backend/main.py
  3. +1 -1    backend/utils/task.py

+ 3 - 10
backend/apps/webui/models/users.py

@@ -1,12 +1,10 @@
-from pydantic import BaseModel, ConfigDict, parse_obj_as
-from typing import Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import Optional
 import time

 from sqlalchemy import String, Column, BigInteger, Text

-from utils.misc import get_gravatar_url
-
-from apps.webui.internal.db import Base, JSONField, Session, get_db
+from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.chats import Chats

 ####################
@@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):


 class UsersTable:
-
     def insert_new_user(
         self,
         id: str,
@@ -122,7 +119,6 @@ class UsersTable:
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(api_key=api_key).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -131,7 +127,6 @@ class UsersTable:
     def get_user_by_email(self, email: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(email=email).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -140,7 +135,6 @@ class UsersTable:
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(oauth_sub=sub).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -195,7 +189,6 @@ class UsersTable:
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 db.query(User).filter_by(id=id).update(
                     {"last_active_at": int(time.time())}
                 )

+ 346 - 436
backend/main.py

@@ -51,13 +51,13 @@ from apps.webui.internal.db import Session


 from pydantic import BaseModel
-from typing import Optional
+from typing import Optional, Callable, Awaitable

 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
 from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users
+from apps.webui.models.users import Users, UserModel

 from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id

@@ -72,7 +72,7 @@ from utils.utils import (
 from utils.task import (
     title_generation_template,
     search_query_generation_template,
-    tools_function_calling_generation_template,
+    tool_calling_generation_template,
 )
 from utils.misc import (
     get_last_user_message,
@@ -261,6 +261,7 @@ def get_filter_function_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0

@@ -282,164 +283,42 @@ def get_filter_function_ids(model):
     return filter_ids


-async def get_function_call_response(
-    messages,
-    files,
-    tool_id,
-    template,
-    task_model_id,
-    user,
-    __event_emitter__=None,
-    __event_call__=None,
-):
-    tool = Tools.get_tool_by_id(tool_id)
-    tools_specs = json.dumps(tool.specs, indent=2)
-    content = tools_function_calling_generation_template(template, tools_specs)
+async def get_content_from_response(response) -> Optional[str]:
+    content = None
+    if hasattr(response, "body_iterator"):
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+    else:
+        content = response["choices"][0]["message"]["content"]
+    return content
+

+def get_tool_call_payload(messages, task_model_id, content):
     user_message = get_last_user_message(messages)
-    prompt = (
-        "History:\n"
-        + "\n".join(
-            [
-                f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
-                for message in messages[::-1][:4]
-            ]
-        )
-        + f"\nQuery: {user_message}"
+    history = "\n".join(
+        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
+        for message in messages[::-1][:4]
     )

-    print(prompt)
+    prompt = f"History:\n{history}\nQuery: {user_message}"

-    payload = {
+    return {
         "model": task_model_id,
         "messages": [
             {"role": "system", "content": content},
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
-        "task": str(TASKS.FUNCTION_CALLING),
+        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
     }

-    try:
-        payload = filter_pipeline(payload, user)
-    except Exception as e:
-        raise e
-
-    model = app.state.MODELS[task_model_id]
-
-    response = None
-    try:
-        response = await generate_chat_completions(form_data=payload, user=user)
-        content = None
-
-        if hasattr(response, "body_iterator"):
-            async for chunk in response.body_iterator:
-                data = json.loads(chunk.decode("utf-8"))
-                content = data["choices"][0]["message"]["content"]
-
-            # Cleanup any remaining background tasks if necessary
-            if response.background is not None:
-                await response.background()
-        else:
-            content = response["choices"][0]["message"]["content"]
-
-        if content is None:
-            return None, None, False
-
-        # Parse the function response
-        print(f"content: {content}")
-        result = json.loads(content)
-        print(result)
-
-        citation = None
-
-        if "name" not in result:
-            return None, None, False
-
-        # Call the function
-        if tool_id in webui_app.state.TOOLS:
-            toolkit_module = webui_app.state.TOOLS[tool_id]
-        else:
-            toolkit_module, _ = load_toolkit_module_by_id(tool_id)
-            webui_app.state.TOOLS[tool_id] = toolkit_module
-
-        file_handler = False
-        # check if toolkit_module has file_handler self variable
-        if hasattr(toolkit_module, "file_handler"):
-            file_handler = True
-            print("file_handler: ", file_handler)
-
-        if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
-            valves = Tools.get_tool_valves_by_id(tool_id)
-            toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
-
-        function = getattr(toolkit_module, result["name"])
-        function_result = None
-        try:
-            # Get the signature of the function
-            sig = inspect.signature(function)
-            params = result["parameters"]
-
-            # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": tool_id,
-                "__messages__": messages,
-                "__files__": files,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                # Call the function with the '__user__' parameter included
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(toolkit_module, "UserValves"):
-                        __user__["valves"] = toolkit_module.UserValves(
-                            **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
-            if inspect.iscoroutinefunction(function):
-                function_result = await function(**params)
-            else:
-                function_result = function(**params)
-
-            if hasattr(toolkit_module, "citation") and toolkit_module.citation:
-                citation = {
-                    "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
-                    "document": [function_result],
-                    "metadata": [{"source": result["name"]}],
-                }
-        except Exception as e:
-            print(e)
-
-        # Add the function result to the system prompt
-        if function_result is not None:
-            return function_result, citation, file_handler
-    except Exception as e:
-        print(f"Error: {e}")
-
-    return None, None, False

-
-async def chat_completion_functions_handler(
-    body, model, user, __event_emitter__, __event_call__
-):
+async def chat_completion_inlets_handler(body, model, extra_params):
     skip_files = None

     filter_ids = get_filter_function_ids(model)
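
A quick standalone sketch of the prompt shape the new get_tool_call_payload builds (message contents here are hypothetical):

    # Sketch only: mirrors the construction in the hunk above.
    messages = [
        {"role": "user", "content": "What is the weather in Berlin?"},
        {"role": "assistant", "content": "Let me check."},
    ]
    user_message = "What is the weather in Berlin?"  # get_last_user_message(messages)
    history = "\n".join(
        f'{m["role"].upper()}: """{m["content"]}"""' for m in messages[::-1][:4]
    )
    prompt = f"History:\n{history}\nQuery: {user_message}"
    # The payload's user message is f"Query: {prompt}", so the final text reads
    # "Query: History:\nASSISTANT: ...\nUSER: ...\nQuery: What is the weather in Berlin?"
    # (the doubled "Query:" prefix is carried over from the original code).
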
@@ -475,37 +354,20 @@ async def chat_completion_functions_handler(
             params = {"body": body}

             # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": filter_id,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
+            custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
+            if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
                 try:
-                    if hasattr(function_module, "UserValves"):
-                        __user__["valves"] = function_module.UserValves(
-                            **Functions.get_user_valves_by_id_and_user_id(
-                                filter_id, user.id
-                            )
-                        )
+                    uid = custom_params["__user__"]["id"]
+                    custom_params["__user__"]["valves"] = function_module.UserValves(
+                        **Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
+                    )
                 except Exception as e:
                     print(e)

-                params = {**params, "__user__": __user__}
+            # Add extra params contained in the function signature
+            for key, value in custom_params.items():
+                if key in sig.parameters:
+                    params[key] = value

             if inspect.iscoroutinefunction(inlet):
                 body = await inlet(**params)
@@ -516,74 +378,171 @@ async def chat_completion_functions_handler(
             print(f"Error: {e}")
             raise e

-    if skip_files:
-        if "files" in body:
-            del body["files"]
+    if skip_files and "files" in body:
+        del body["files"]

     return body, {}


-async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
-    skip_files = None
+def get_tool_with_custom_params(
+    tool: Callable, custom_params: dict
+) -> Callable[..., Awaitable]:
+    sig = inspect.signature(tool)
+    extra_params = {
+        key: value for key, value in custom_params.items() if key in sig.parameters
+    }
+    is_coroutine = inspect.iscoroutinefunction(tool)

-    contexts = []
-    citations = None
+    async def new_tool(**kwargs):
+        extra_kwargs = kwargs | extra_params
+        if is_coroutine:
+            return await tool(**extra_kwargs)
+        return tool(**extra_kwargs)
+
+    return new_tool
+
+
+# NOTE: mutates extra_params in place
+def get_configured_tools(
+    tool_ids: list[str], extra_params: dict, user: UserModel
+) -> dict[str, dict]:
+    tools = {}
+    for tool_id in tool_ids:
+        toolkit = Tools.get_tool_by_id(tool_id)
+        if toolkit is None:
+            continue

+        module = webui_app.state.TOOLS.get(tool_id, None)
+        if module is None:
+            module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = module
+
+        extra_params["__id__"] = tool_id
+        has_citation = hasattr(module, "citation") and module.citation
+        handles_files = hasattr(module, "file_handler") and module.file_handler
+        if hasattr(module, "valves") and hasattr(module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id) or {}
+            module.valves = module.Valves(**valves)
+
+        if hasattr(module, "UserValves"):
+            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+        for spec in toolkit.specs:
+            # TODO: Fix hack for OpenAI API
+            for val in spec.get("parameters", {}).get("properties", {}).values():
+                if val["type"] == "str":
+                    val["type"] = "string"
+            name = spec["name"]
+            callable = getattr(module, name)
+
+            # convert to function that takes only model params and inserts custom params
+            custom_callable = get_tool_with_custom_params(callable, extra_params)
+
+            # TODO: This needs to be a pydantic model
+            tool_dict = {
+                "spec": spec,
+                "citation": has_citation,
+                "file_handler": handles_files,
+                "toolkit_id": tool_id,
+                "callable": custom_callable,
+            }
+            # TODO: if collision, prepend toolkit name
+            if name in tools:
+                log.warning(f"Tool {name} already exists in another toolkit!")
+                log.warning(f"Collision between {toolkit} and {tool_id}.")
+                log.warning(f"Discarding {toolkit}.{name}")
+            else:
+                tools[name] = tool_dict
+
+    return tools
+
+
+async def chat_completion_tools_handler(
+    body: dict, user: UserModel, extra_params: dict
+) -> tuple[dict, dict]:
+    skip_files = False
+    contexts = []
+    citations = []
     task_model_id = get_task_model_id(body["model"])

     # If tool_ids field is present, call the functions
-    if "tool_ids" in body:
-        print(body["tool_ids"])
-        for tool_id in body["tool_ids"]:
-            print(tool_id)
-            try:
-                response, citation, file_handler = await get_function_call_response(
-                    messages=body["messages"],
-                    files=body.get("files", []),
-                    tool_id=tool_id,
-                    template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                    task_model_id=task_model_id,
-                    user=user,
-                    __event_emitter__=__event_emitter__,
-                    __event_call__=__event_call__,
-                )
+    tool_ids = body.pop("tool_ids", None)
+    if not tool_ids:
+        return body, {}
+
+    log.debug(f"{tool_ids=}")
+    custom_params = {
+        **extra_params,
+        "__model__": app.state.MODELS[task_model_id],
+        "__messages__": body["messages"],
+        "__files__": body.get("files", []),
+    }
+    configured_tools = get_configured_tools(tool_ids, custom_params, user)

-                print(file_handler)
-                if isinstance(response, str):
-                    contexts.append(response)
+    log.info(f"{configured_tools=}")

-                if citation:
-                    if citations is None:
-                        citations = [citation]
-                    else:
-                        citations.append(citation)
+    specs = [tool["spec"] for tool in configured_tools.values()]
+    tools_specs = json.dumps(specs)
+    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+    content = tool_calling_generation_template(template, tools_specs)
+    payload = get_tool_call_payload(body["messages"], task_model_id, content)
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        raise e

-                if file_handler:
-                    skip_files = True
+    try:
+        response = await generate_chat_completions(form_data=payload, user=user)
+        log.debug(f"{response=}")
+        content = await get_content_from_response(response)
+        log.debug(f"{content=}")
+        if content is None:
+            return body, {}

-            except Exception as e:
-                print(f"Error: {e}")
-        del body["tool_ids"]
-        print(f"tool_contexts: {contexts}")
+        result = json.loads(content)
+        tool_name = result.get("name", None)
+        if tool_name not in configured_tools:
+            return body, {}

-    if skip_files:
-        if "files" in body:
-            del body["files"]
+        tool_params = result.get("parameters", {})
+        toolkit_id = configured_tools[tool_name]["toolkit_id"]
+        try:
+            tool_output = await configured_tools[tool_name]["callable"](**tool_params)
+        except Exception as e:
+            tool_output = str(e)
+        if configured_tools[tool_name]["citation"]:
+            citations.append(
+                {
+                    "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
+                    "document": [tool_output],
+                    "metadata": [{"source": tool_name}],
+                }
+            )
+        if configured_tools[tool_name]["file_handler"]:
+            skip_files = True

-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
+        if isinstance(tool_output, str):
+            contexts.append(tool_output)

+    except Exception as e:
+        print(f"Error: {e}")
+        content = None

-async def chat_completion_files_handler(body):
-    contexts = []
-    citations = None
+    log.debug(f"tool_contexts: {contexts}")

-    if "files" in body:
-        files = body["files"]
+    if skip_files and "files" in body:
         del body["files"]

+    return body, {"contexts": contexts, "citations": citations}
+
+
+async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
+    contexts = []
+    citations = []
+
+    if files := body.pop("files", None):
         contexts, citations = get_rag_context(
             files=files,
             messages=body["messages"],
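
A minimal usage sketch of get_tool_with_custom_params: the wrapper injects only the extra params that the tool's signature declares, so kwargs supplied by the model stay clean (the tool body below is hypothetical):

    import asyncio

    def lookup(city: str, __user__: dict):
        # Receives __user__ only because it appears in the signature.
        return f"{city}: sunny (requested by {__user__['id']})"

    wrapped = get_tool_with_custom_params(
        lookup, {"__user__": {"id": "u1"}, "__files__": []}
    )
    # "__files__" is filtered out (not in the signature); the model supplies "city".
    print(asyncio.run(wrapped(city="Berlin")))  # Berlin: sunny (requested by u1)
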
@@ -596,134 +555,130 @@ async def chat_completion_files_handler(body):

         log.debug(f"rag_contexts: {contexts}, citations: {citations}")

-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
-
+    return body, {"contexts": contexts, "citations": citations}

-class ChatCompletionMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")

-            try:
-                body, model, user = await get_body_and_model_and_user(request)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
-
-            metadata = {
-                "chat_id": body.pop("chat_id", None),
-                "message_id": body.pop("id", None),
-                "session_id": body.pop("session_id", None),
-                "valves": body.pop("valves", None),
-            }
-
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
+def is_chat_completion_request(request):
+    return request.method == "POST" and any(
+        endpoint in request.url.path
+        for endpoint in ["/ollama/api/chat", "/chat/completions"]
+    )

-            # Initialize data_items to store additional data to be sent to the client
-            data_items = []

-            # Initialize context, and citations
-            contexts = []
-            citations = []
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if not is_chat_completion_request(request):
+            return await call_next(request)
+        log.debug(f"request.url.path: {request.url.path}")

-            try:
-                body, flags = await chat_completion_functions_handler(
-                    body, model, user, __event_emitter__, __event_call__
-                )
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
+        try:
+            body, model, user = await get_body_and_model_and_user(request)
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )

-            try:
-                body, flags = await chat_completion_tools_handler(
-                    body, user, __event_emitter__, __event_call__
-                )
+        metadata = {
+            "chat_id": body.pop("chat_id", None),
+            "message_id": body.pop("id", None),
+            "session_id": body.pop("session_id", None),
+            "valves": body.pop("valves", None),
+        }

-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }

-            try:
-                body, flags = await chat_completion_files_handler(body)
+        extra_params = {
+            "__user__": __user__,
+            "__event_emitter__": get_event_emitter(metadata),
+            "__event_call__": get_event_call(metadata),
+        }

-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        # Initialize data_items to store additional data to be sent to the client
+        # Initialize contexts and citations
+        data_items = []
+        contexts = []
+        citations = []

-            # If context is not empty, insert it into the messages
-            if len(contexts) > 0:
-                context_string = "/n".join(contexts).strip()
-                prompt = get_last_user_message(body["messages"])
-
-                # Workaround for Ollama 2.0+ system prompt issue
-                # TODO: replace with add_or_update_system_message
-                if model["owned_by"] == "ollama":
-                    body["messages"] = prepend_to_first_user_message_content(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-                else:
-                    body["messages"] = add_or_update_system_message(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
+        try:
+            body, flags = await chat_completion_inlets_handler(
+                body, model, extra_params
+            )
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )

-            # If there are citations, add them to the data_items
-            if len(citations) > 0:
-                data_items.append({"citations": citations})
-
-            body["metadata"] = metadata
-            modified_body_bytes = json.dumps(body).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
-
-            response = await call_next(request)
-            if isinstance(response, StreamingResponse):
-                # If it's a streaming response, inject it as SSE event or NDJSON line
-                content_type = response.headers.get("Content-Type")
-                if "text/event-stream" in content_type:
-                    return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, data_items),
-                    )
-                if "application/x-ndjson" in content_type:
-                    return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, data_items),
-                    )
+        try:
+            body, flags = await chat_completion_tools_handler(body, user, extra_params)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)

-                return response
+        try:
+            body, flags = await chat_completion_files_handler(body)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
+
+        # If context is not empty, insert it into the messages
+        if len(contexts) > 0:
+            context_string = "\n".join(contexts).strip()
+            prompt = get_last_user_message(body["messages"])
+            if prompt is None:
+                raise Exception("No user message found")
+            # Workaround for Ollama 2.0+ system prompt issue
+            # TODO: replace with add_or_update_system_message
+            if model["owned_by"] == "ollama":
+                body["messages"] = prepend_to_first_user_message_content(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
             else:
-                return response
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
+
+        # If there are citations, add them to the data_items
+        if len(citations) > 0:
+            data_items.append({"citations": citations})
+
+        body["metadata"] = metadata
+        modified_body_bytes = json.dumps(body).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]

-        # If it's not a chat completion request, just pass it through
         response = await call_next(request)
+        if isinstance(response, StreamingResponse):
+            # If it's a streaming response, inject it as SSE event or NDJSON line
+            content_type = response.headers["Content-Type"]
+            if "text/event-stream" in content_type:
+                return StreamingResponse(
+                    self.openai_stream_wrapper(response.body_iterator, data_items),
+                )
+            if "application/x-ndjson" in content_type:
+                return StreamingResponse(
+                    self.ollama_stream_wrapper(response.body_iterator, data_items),
+                )
+
         return response

     async def _receive(self, body: bytes):
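
Both middlewares reuse the body-swap pattern above; a small sketch of why the content-length header has to be recomputed when request._body is replaced (the header values here are hypothetical):

    import json

    body = {"model": "llama3", "messages": []}  # modified payload
    modified_body_bytes = json.dumps(body).encode("utf-8")
    raw = [(b"content-length", b"999"), (b"accept", b"*/*")]  # stale length
    # Drop the stale content-length and prepend the correct one, as in dispatch():
    raw = [
        (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
        *[(k, v) for k, v in raw if k.lower() != b"content-length"],
    ]
    # A downstream handler that trusts the old value could read a truncated body.
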
@@ -790,19 +745,21 @@ def filter_pipeline(payload, user):
             url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
             key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/inlet",
-                    headers=headers,
-                    json={
-                        "user": user,
-                        "body": payload,
-                    },
-                )
+            if key == "":
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            r = requests.post(
+                f"{url}/{filter['id']}/filter/inlet",
+                headers=headers,
+                json={
+                    "user": user,
+                    "body": payload,
+                },
+            )

-                r.raise_for_status()
-                payload = r.json()
+            r.raise_for_status()
+            payload = r.json()
         except Exception as e:
             # Handle connection error here
             print(f"Connection error: {e}")
@@ -817,44 +774,39 @@ def filter_pipeline(payload, user):

 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
-
-            # Read the original request body
-            body = await request.body()
-            # Decode body to string
-            body_str = body.decode("utf-8")
-            # Parse string to JSON
-            data = json.loads(body_str) if body_str else {}
-
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
+        if not is_chat_completion_request(request):
+            return await call_next(request)

-            try:
-                data = filter_pipeline(data, user)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=e.args[0],
-                    content={"detail": e.args[1]},
-                )
+        log.debug(f"request.url.path: {request.url.path}")

-            modified_body_bytes = json.dumps(data).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
+        # Read the original request body
+        body = await request.body()
+        # Decode body to string
+        body_str = body.decode("utf-8")
+        # Parse string to JSON
+        data = json.loads(body_str) if body_str else {}
+
+        user = get_current_user(
+            request,
+            get_http_authorization_cred(request.headers["Authorization"]),
+        )
+
+        try:
+            data = filter_pipeline(data, user)
+        except Exception as e:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
+            )
+
+        modified_body_bytes = json.dumps(data).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]

         response = await call_next(request)
         return response
@@ -1019,6 +971,8 @@ async def get_all_models():
         model["actions"] = []
         for action_id in action_ids:
             action = Functions.get_function_by_id(action_id)
+            if action is None:
+                raise Exception(f"Action not found: {action_id}")

             if action_id in webui_app.state.FUNCTIONS:
                 function_module = webui_app.state.FUNCTIONS[action_id]
@@ -1099,22 +1053,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         )
     model = app.state.MODELS[model_id]

-    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
-    task = None
-    if "task" in form_data:
-        task = form_data["task"]
-        del form_data["task"]
-
-    if task:
-        if "metadata" in form_data:
-            form_data["metadata"]["task"] = task
-        else:
-            form_data["metadata"] = {"task": task}
-
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
-        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
@@ -1198,6 +1139,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel to include valves
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0

@@ -1487,7 +1429,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "stream": False,
         "max_tokens": 50,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.TITLE_GENERATION),
+        "metadata": {"task": str(TASKS.TITLE_GENERATION)},
     }
     }
 
 
     log.debug(payload)
     log.debug(payload)
@@ -1540,7 +1482,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "stream": False,
         "max_tokens": 30,
         "max_tokens": 30,
-        "task": str(TASKS.QUERY_GENERATION),
+        "metadata": {"task": str(TASKS.QUERY_GENERATION)},
     }
     }
 
 
     print(payload)
     print(payload)
@@ -1597,7 +1539,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "stream": False,
         "max_tokens": 4,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.EMOJI_GENERATION),
+        "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
     }
     }
 
 
     log.debug(payload)
     log.debug(payload)
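
These three hunks follow from dropping the task handling in generate_chat_completions: callers now put the task label into metadata when they build the payload, instead of having it stripped and re-packed downstream. A sketch of the resulting shape, assuming str(TASKS.TITLE_GENERATION) yields a plain label:

    payload = {
        "model": "gpt-task-model",  # hypothetical task model id
        "messages": [{"role": "user", "content": "Generate a title for ..."}],
        "stream": False,
        "metadata": {"task": "title_generation"},  # assumed string form of the enum
    }
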
@@ -1616,41 +1558,6 @@ Message: """{{prompt}}"""
     return await generate_chat_completions(form_data=payload, user=user)


-@app.post("/api/task/tools/completions")
-async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
-    print("get_tools_function_calling")
-
-    model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Model not found",
-        )
-
-    # Check if the user has a custom task model
-    # If the user has a custom task model, use that model
-    model_id = get_task_model_id(model_id)
-
-    print(model_id)
-    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
-
-    try:
-        context, _, _ = await get_function_call_response(
-            form_data["messages"],
-            form_data.get("files", []),
-            form_data["tool_id"],
-            template,
-            model_id,
-            user,
-        )
-        return context
-    except Exception as e:
-        return JSONResponse(
-            status_code=e.args[0],
-            content={"detail": e.args[1]},
-        )
-
-
 ##################################
 #
 # Pipelines Endpoints
@@ -1689,7 +1596,7 @@ async def upload_pipeline(
 ):
     print("upload_pipeline", urlIdx, file.filename)
     # Check if the uploaded file is a python file
-    if not file.filename.endswith(".py"):
+    if not (file.filename and file.filename.endswith(".py")):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Only Python (.py) files are allowed.",
@@ -2138,7 +2045,10 @@ async def oauth_login(provider: str, request: Request):
     redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
         "oauth_callback", provider=provider
     )
-    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+    client = oauth.create_client(provider)
+    if client is None:
+        raise HTTPException(404)
+    return await client.authorize_redirect(request, redirect_uri)


 # OAuth login logic is as follows:

+ 1 - 1
backend/utils/task.py

@@ -121,6 +121,6 @@ def search_query_generation_template(
     return template


-def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
+def tool_calling_generation_template(template: str, tools_specs: str) -> str:
     template = template.replace("{{TOOLS}}", tools_specs)
     return template
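
A short usage sketch of the renamed helper (the template text here is hypothetical; the real call site is chat_completion_tools_handler in backend/main.py):

    import json

    template = "You may call one of these tools:\n{{TOOLS}}\nReply with a JSON function call."
    tools_specs = json.dumps(
        [{"name": "lookup", "parameters": {"properties": {"city": {"type": "string"}}}}]
    )
    print(tool_calling_generation_template(template, tools_specs))
    # {{TOOLS}} is replaced verbatim with the JSON-dumped spec list.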