Browse Source

Merge pull request #9631 from XingjianXie/remove_inlet_outlet_duplication

Refactor common code between inlet and outlet
Timothy Jaeryang Baek 2 months ago
parent
commit
79c0b45543
3 changed files with 154 additions and 208 deletions
  1. 34 109
      backend/open_webui/utils/chat.py
  2. 97 0
      backend/open_webui/utils/filter.py
  3. 23 99
      backend/open_webui/utils/middleware.py

+ 34 - 109
backend/open_webui/utils/chat.py

@@ -44,6 +44,10 @@ from open_webui.utils.response import (
     convert_response_ollama_to_openai,
     convert_streaming_response_ollama_to_openai,
 )
+from open_webui.utils.filter import (
+    get_sorted_filter_ids,
+    process_filter_functions,
+)
 
 from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
 
@@ -177,116 +181,37 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
     except Exception as e:
         return Exception(f"Error: {e}")
 
-    __event_emitter__ = get_event_emitter(
-        {
-            "chat_id": data["chat_id"],
-            "message_id": data["id"],
-            "session_id": data["session_id"],
-            "user_id": user.id,
-        }
-    )
-
-    __event_call__ = get_event_call(
-        {
-            "chat_id": data["chat_id"],
-            "message_id": data["id"],
-            "session_id": data["session_id"],
-            "user_id": user.id,
-        }
-    )
-
-    def get_priority(function_id):
-        function = Functions.get_function_by_id(function_id)
-        if function is not None and hasattr(function, "valves"):
-            # TODO: Fix FunctionModel to include vavles
-            return (function.valves if function.valves else {}).get("priority", 0)
-        return 0
-
-    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
-    if "info" in model and "meta" in model["info"]:
-        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-        filter_ids = list(set(filter_ids))
-
-    enabled_filter_ids = [
-        function.id
-        for function in Functions.get_functions_by_type("filter", active_only=True)
-    ]
-    filter_ids = [
-        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-    ]
-
-    # Sort filter_ids by priority, using the get_priority function
-    filter_ids.sort(key=get_priority)
-
-    for filter_id in filter_ids:
-        filter = Functions.get_function_by_id(filter_id)
-        if not filter:
-            continue
-
-        if filter_id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[filter_id]
-        else:
-            function_module, _, _ = load_function_module_by_id(filter_id)
-            request.app.state.FUNCTIONS[filter_id] = function_module
-
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-            valves = Functions.get_function_valves_by_id(filter_id)
-            function_module.valves = function_module.Valves(
-                **(valves if valves else {})
-            )
-
-        if not hasattr(function_module, "outlet"):
-            continue
-        try:
-            outlet = function_module.outlet
-
-            # Get the signature of the function
-            sig = inspect.signature(outlet)
-            params = {"body": data}
-
-            # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": filter_id,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-                "__request__": request,
-            }
+    metadata = {
+        "chat_id": data["chat_id"],
+        "message_id": data["id"],
+        "session_id": data["session_id"],
+        "user_id": user.id,
+    }
+
+    extra_params = {
+        "__event_emitter__": get_event_emitter(metadata),
+        "__event_call__": get_event_call(metadata),
+        "__user__": {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        },
+        "__metadata__": metadata,
+        "__request__": request,
+    }
 
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(function_module, "UserValves"):
-                        __user__["valves"] = function_module.UserValves(
-                            **Functions.get_user_valves_by_id_and_user_id(
-                                filter_id, user.id
-                            )
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
-            if inspect.iscoroutinefunction(outlet):
-                data = await outlet(**params)
-            else:
-                data = outlet(**params)
-
-        except Exception as e:
-            return Exception(f"Error: {e}")
-
-    return data
+    try:
+        result, _ = await process_filter_functions(
+            request=request,
+            filter_ids=get_sorted_filter_ids(model),
+            filter_type="outlet",
+            form_data=data,
+            extra_params=extra_params,
+        )
+        return result
+    except Exception as e:
+        return Exception(f"Error: {e}")
 
 
 async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):

+ 97 - 0
backend/open_webui/utils/filter.py

@@ -0,0 +1,97 @@
+import inspect
+from open_webui.utils.plugin import load_function_module_by_id
+from open_webui.models.functions import Functions
+
+
+def get_sorted_filter_ids(model):
+    def get_priority(function_id):
+        function = Functions.get_function_by_id(function_id)
+        if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel to include vavles
+            return (function.valves if function.valves else {}).get("priority", 0)
+        return 0
+
+    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+    if "info" in model and "meta" in model["info"]:
+        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+        filter_ids = list(set(filter_ids))
+
+    enabled_filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
+
+    filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
+    filter_ids.sort(key=get_priority)
+    return filter_ids
+
+
+async def process_filter_functions(
+    request, filter_ids, filter_type, form_data, extra_params
+):
+    skip_files = None
+
+    for filter_id in filter_ids:
+        filter = Functions.get_function_by_id(filter_id)
+        if not filter:
+            continue
+
+        if filter_id in request.app.state.FUNCTIONS:
+            function_module = request.app.state.FUNCTIONS[filter_id]
+        else:
+            function_module, _, _ = load_function_module_by_id(filter_id)
+            request.app.state.FUNCTIONS[filter_id] = function_module
+
+        # Check if the function has a file_handler variable
+        if filter_type == "inlet" and hasattr(function_module, "file_handler"):
+            skip_files = function_module.file_handler
+
+        # Apply valves to the function
+        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+            valves = Functions.get_function_valves_by_id(filter_id)
+            function_module.valves = function_module.Valves(
+                **(valves if valves else {})
+            )
+
+        # Prepare handler function
+        handler = getattr(function_module, filter_type, None)
+        if not handler:
+            continue
+
+        try:
+            # Prepare parameters
+            sig = inspect.signature(handler)
+            params = {"body": form_data}
+
+            # Add extra parameters that exist in the handler's signature
+            for key in list(extra_params.keys()):
+                if key in sig.parameters:
+                    params[key] = extra_params[key]
+
+            # Handle user parameters
+            if "__user__" in sig.parameters:
+                if hasattr(function_module, "UserValves"):
+                    try:
+                        params["__user__"]["valves"] = function_module.UserValves(
+                            **Functions.get_user_valves_by_id_and_user_id(
+                                filter_id, params["__user__"]["id"]
+                            )
+                        )
+                    except Exception as e:
+                        print(e)
+
+            # Execute handler
+            if inspect.iscoroutinefunction(handler):
+                form_data = await handler(**params)
+            else:
+                form_data = handler(**params)
+
+        except Exception as e:
+            print(f"Error in {filter_type} handler {filter_id}: {e}")
+            raise e
+
+    # Handle file cleanup for inlet
+    if skip_files and "files" in form_data.get("metadata", {}):
+        del form_data["metadata"]["files"]
+
+    return form_data, {}

+ 23 - 99
backend/open_webui/utils/middleware.py

@@ -68,6 +68,10 @@ from open_webui.utils.misc import (
 )
 from open_webui.utils.tools import get_tools
 from open_webui.utils.plugin import load_function_module_by_id
+from open_webui.utils.filter import (
+    get_sorted_filter_ids,
+    process_filter_functions,
+)
 
 
 from open_webui.tasks import create_task
@@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
-async def chat_completion_filter_functions_handler(request, body, model, extra_params):
-    skip_files = None
-
-    def get_filter_function_ids(model):
-        def get_priority(function_id):
-            function = Functions.get_function_by_id(function_id)
-            if function is not None and hasattr(function, "valves"):
-                # TODO: Fix FunctionModel
-                return (function.valves if function.valves else {}).get("priority", 0)
-            return 0
-
-        filter_ids = [
-            function.id for function in Functions.get_global_filter_functions()
-        ]
-        if "info" in model and "meta" in model["info"]:
-            filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-            filter_ids = list(set(filter_ids))
-
-        enabled_filter_ids = [
-            function.id
-            for function in Functions.get_functions_by_type("filter", active_only=True)
-        ]
-
-        filter_ids = [
-            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-        ]
-
-        filter_ids.sort(key=get_priority)
-        return filter_ids
-
-    filter_ids = get_filter_function_ids(model)
-    for filter_id in filter_ids:
-        filter = Functions.get_function_by_id(filter_id)
-        if not filter:
-            continue
-
-        if filter_id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[filter_id]
-        else:
-            function_module, _, _ = load_function_module_by_id(filter_id)
-            request.app.state.FUNCTIONS[filter_id] = function_module
-
-        # Check if the function has a file_handler variable
-        if hasattr(function_module, "file_handler"):
-            skip_files = function_module.file_handler
-
-        # Apply valves to the function
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-            valves = Functions.get_function_valves_by_id(filter_id)
-            function_module.valves = function_module.Valves(
-                **(valves if valves else {})
-            )
-
-        if hasattr(function_module, "inlet"):
-            try:
-                inlet = function_module.inlet
-
-                # Create a dictionary of parameters to be passed to the function
-                params = {"body": body} | {
-                    k: v
-                    for k, v in {
-                        **extra_params,
-                        "__model__": model,
-                        "__id__": filter_id,
-                    }.items()
-                    if k in inspect.signature(inlet).parameters
-                }
-
-                if "__user__" in params and hasattr(function_module, "UserValves"):
-                    try:
-                        params["__user__"]["valves"] = function_module.UserValves(
-                            **Functions.get_user_valves_by_id_and_user_id(
-                                filter_id, params["__user__"]["id"]
-                            )
-                        )
-                    except Exception as e:
-                        print(e)
-
-                if inspect.iscoroutinefunction(inlet):
-                    body = await inlet(**params)
-                else:
-                    body = inlet(**params)
-
-            except Exception as e:
-                print(f"Error: {e}")
-                raise e
-
-    if skip_files and "files" in body.get("metadata", {}):
-        del body["metadata"]["files"]
-
-    return body, {}
-
-
 async def chat_completion_tools_handler(
     request: Request, body: dict, user: UserModel, models, tools
 ) -> tuple[dict, dict]:
@@ -782,8 +693,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
             )
 
     try:
-        form_data, flags = await chat_completion_filter_functions_handler(
-            request, form_data, model, extra_params
+        form_data, flags = await process_filter_functions(
+            request=request,
+            filter_ids=get_sorted_filter_ids(model),
+            filter_type="inlet",
+            form_data=form_data,
+            extra_params=extra_params,
         )
     except Exception as e:
         raise Exception(f"Error: {e}")
@@ -1124,11 +1039,15 @@ async def process_chat_response(
 
         def split_content_and_whitespace(content):
             content_stripped = content.rstrip()
-            original_whitespace = content[len(content_stripped):] if len(content) > len(content_stripped) else ''
+            original_whitespace = (
+                content[len(content_stripped) :]
+                if len(content) > len(content_stripped)
+                else ""
+            )
             return content_stripped, original_whitespace
 
         def is_opening_code_block(content):
-            backtick_segments = content.split('```')
+            backtick_segments = content.split("```")
             # Even number of segments means the last backticks are opening a new block
             return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
 
@@ -1198,10 +1117,15 @@ async def process_chat_response(
                         output = block.get("output", None)
                         lang = attributes.get("lang", "")
 
-                        content_stripped, original_whitespace = split_content_and_whitespace(content)
+                        content_stripped, original_whitespace = (
+                            split_content_and_whitespace(content)
+                        )
                         if is_opening_code_block(content_stripped):
                             # Remove trailing backticks that would open a new block
-                            content = content_stripped.rstrip('`').rstrip() + original_whitespace
+                            content = (
+                                content_stripped.rstrip("`").rstrip()
+                                + original_whitespace
+                            )
                         else:
                             # Keep content as is - either closing backticks or no backticks
                             content = content_stripped + original_whitespace