Timothy Jaeryang Baek преди 2 месеца
родител
ревизия
3dde2f67cf
променени са 3 файла, в които са добавени 32 реда и са изтрити 26 реда
  1. 3 3
      backend/open_webui/utils/chat.py
  2. 13 16
      backend/open_webui/utils/filter.py
  3. 16 7
      backend/open_webui/utils/middleware.py

+ 3 - 3
backend/open_webui/utils/chat.py

@@ -203,10 +203,10 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
 
     try:
         result, _ = await process_filter_functions(
-            handler_type="outlet",
-            filter_ids=get_sorted_filter_ids(model),
             request=request,
-            data=data,
+            filter_ids=get_sorted_filter_ids(model),
+            filter_type="outlet",
+            form_data=data,
             extra_params=extra_params,
         )
         return result

+ 13 - 16
backend/open_webui/utils/filter.py

@@ -2,6 +2,7 @@ import inspect
 from open_webui.utils.plugin import load_function_module_by_id
 from open_webui.models.functions import Functions
 
+
 def get_sorted_filter_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
@@ -19,17 +20,14 @@ def get_sorted_filter_ids(model):
         function.id
         for function in Functions.get_functions_by_type("filter", active_only=True)
     ]
-    
+
     filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
     filter_ids.sort(key=get_priority)
     return filter_ids
 
+
 async def process_filter_functions(
-    handler_type,
-    filter_ids,
-    request,
-    data,
-    extra_params
+    request, filter_ids, filter_type, form_data, extra_params
 ):
     skip_files = None
 
@@ -45,7 +43,7 @@ async def process_filter_functions(
             request.app.state.FUNCTIONS[filter_id] = function_module
 
         # Check if the function has a file_handler variable
-        if handler_type == "inlet" and hasattr(function_module, "file_handler"):
+        if filter_type == "inlet" and hasattr(function_module, "file_handler"):
             skip_files = function_module.file_handler
 
         # Apply valves to the function
@@ -56,14 +54,14 @@ async def process_filter_functions(
             )
 
         # Prepare handler function
-        handler = getattr(function_module, handler_type, None)
+        handler = getattr(function_module, filter_type, None)
         if not handler:
             continue
 
         try:
             # Prepare parameters
             sig = inspect.signature(handler)
-            params = {"body": data}
+            params = {"body": form_data}
 
             # Add extra parameters that exist in the handler's signature
             for key in list(extra_params.keys()):
@@ -82,19 +80,18 @@ async def process_filter_functions(
                     except Exception as e:
                         print(e)
 
-
             # Execute handler
             if inspect.iscoroutinefunction(handler):
-                data = await handler(**params)
+                form_data = await handler(**params)
             else:
-                data = handler(**params)
+                form_data = handler(**params)
 
         except Exception as e:
-            print(f"Error in {handler_type} handler {filter_id}: {e}")
+            print(f"Error in {filter_type} handler {filter_id}: {e}")
             raise e
 
     # Handle file cleanup for inlet
-    if skip_files and "files" in data.get("metadata", {}):
-        del data["metadata"]["files"]
+    if skip_files and "files" in form_data.get("metadata", {}):
+        del form_data["metadata"]["files"]
 
-    return data, {}
+    return form_data, {}

+ 16 - 7
backend/open_webui/utils/middleware.py

@@ -694,10 +694,10 @@ async def process_chat_payload(request, form_data, metadata, user, model):
 
     try:
         form_data, flags = await process_filter_functions(
-            handler_type="inlet",
-            filter_ids=get_sorted_filter_ids(model),
             request=request,
-            data=form_data,
+            filter_ids=get_sorted_filter_ids(model),
+            filter_type="inlet",
+            form_data=form_data,
             extra_params=extra_params,
         )
     except Exception as e:
@@ -1039,11 +1039,15 @@ async def process_chat_response(
 
         def split_content_and_whitespace(content):
             content_stripped = content.rstrip()
-            original_whitespace = content[len(content_stripped):] if len(content) > len(content_stripped) else ''
+            original_whitespace = (
+                content[len(content_stripped) :]
+                if len(content) > len(content_stripped)
+                else ""
+            )
             return content_stripped, original_whitespace
 
         def is_opening_code_block(content):
-            backtick_segments = content.split('```')
+            backtick_segments = content.split("```")
             # Even number of segments means the last backticks are opening a new block
             return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
 
@@ -1113,10 +1117,15 @@ async def process_chat_response(
                         output = block.get("output", None)
                         lang = attributes.get("lang", "")
 
-                        content_stripped, original_whitespace = split_content_and_whitespace(content)
+                        content_stripped, original_whitespace = (
+                            split_content_and_whitespace(content)
+                        )
                         if is_opening_code_block(content_stripped):
                             # Remove trailing backticks that would open a new block
-                            content = content_stripped.rstrip('`').rstrip() + original_whitespace
+                            content = (
+                                content_stripped.rstrip("`").rstrip()
+                                + original_whitespace
+                            )
                         else:
                             # Keep content as is - either closing backticks or no backticks
                             content = content_stripped + original_whitespace