|
@@ -68,6 +68,10 @@ from open_webui.utils.misc import (
|
|
|
)
|
|
|
from open_webui.utils.tools import get_tools
|
|
|
from open_webui.utils.plugin import load_function_module_by_id
|
|
|
+from open_webui.utils.filter import (
|
|
|
+ get_sorted_filter_ids,
|
|
|
+ process_filter_functions,
|
|
|
+)
|
|
|
|
|
|
|
|
|
from open_webui.tasks import create_task
|
|
@@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
|
|
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|
|
|
|
|
|
|
|
-async def chat_completion_filter_functions_handler(request, body, model, extra_params):
|
|
|
- skip_files = None
|
|
|
-
|
|
|
- def get_filter_function_ids(model):
|
|
|
- def get_priority(function_id):
|
|
|
- function = Functions.get_function_by_id(function_id)
|
|
|
- if function is not None and hasattr(function, "valves"):
|
|
|
- # TODO: Fix FunctionModel
|
|
|
- return (function.valves if function.valves else {}).get("priority", 0)
|
|
|
- return 0
|
|
|
-
|
|
|
- filter_ids = [
|
|
|
- function.id for function in Functions.get_global_filter_functions()
|
|
|
- ]
|
|
|
- if "info" in model and "meta" in model["info"]:
|
|
|
- filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
|
|
- filter_ids = list(set(filter_ids))
|
|
|
-
|
|
|
- enabled_filter_ids = [
|
|
|
- function.id
|
|
|
- for function in Functions.get_functions_by_type("filter", active_only=True)
|
|
|
- ]
|
|
|
-
|
|
|
- filter_ids = [
|
|
|
- filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
|
|
- ]
|
|
|
-
|
|
|
- filter_ids.sort(key=get_priority)
|
|
|
- return filter_ids
|
|
|
-
|
|
|
- filter_ids = get_filter_function_ids(model)
|
|
|
- for filter_id in filter_ids:
|
|
|
- filter = Functions.get_function_by_id(filter_id)
|
|
|
- if not filter:
|
|
|
- continue
|
|
|
-
|
|
|
- if filter_id in request.app.state.FUNCTIONS:
|
|
|
- function_module = request.app.state.FUNCTIONS[filter_id]
|
|
|
- else:
|
|
|
- function_module, _, _ = load_function_module_by_id(filter_id)
|
|
|
- request.app.state.FUNCTIONS[filter_id] = function_module
|
|
|
-
|
|
|
- # Check if the function has a file_handler variable
|
|
|
- if hasattr(function_module, "file_handler"):
|
|
|
- skip_files = function_module.file_handler
|
|
|
-
|
|
|
- # Apply valves to the function
|
|
|
- if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
|
|
- valves = Functions.get_function_valves_by_id(filter_id)
|
|
|
- function_module.valves = function_module.Valves(
|
|
|
- **(valves if valves else {})
|
|
|
- )
|
|
|
-
|
|
|
- if hasattr(function_module, "inlet"):
|
|
|
- try:
|
|
|
- inlet = function_module.inlet
|
|
|
-
|
|
|
- # Create a dictionary of parameters to be passed to the function
|
|
|
- params = {"body": body} | {
|
|
|
- k: v
|
|
|
- for k, v in {
|
|
|
- **extra_params,
|
|
|
- "__model__": model,
|
|
|
- "__id__": filter_id,
|
|
|
- }.items()
|
|
|
- if k in inspect.signature(inlet).parameters
|
|
|
- }
|
|
|
-
|
|
|
- if "__user__" in params and hasattr(function_module, "UserValves"):
|
|
|
- try:
|
|
|
- params["__user__"]["valves"] = function_module.UserValves(
|
|
|
- **Functions.get_user_valves_by_id_and_user_id(
|
|
|
- filter_id, params["__user__"]["id"]
|
|
|
- )
|
|
|
- )
|
|
|
- except Exception as e:
|
|
|
- print(e)
|
|
|
-
|
|
|
- if inspect.iscoroutinefunction(inlet):
|
|
|
- body = await inlet(**params)
|
|
|
- else:
|
|
|
- body = inlet(**params)
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"Error: {e}")
|
|
|
- raise e
|
|
|
-
|
|
|
- if skip_files and "files" in body.get("metadata", {}):
|
|
|
- del body["metadata"]["files"]
|
|
|
-
|
|
|
- return body, {}
|
|
|
-
|
|
|
-
|
|
|
async def chat_completion_tools_handler(
|
|
|
request: Request, body: dict, user: UserModel, models, tools
|
|
|
) -> tuple[dict, dict]:
|
|
@@ -782,8 +693,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|
|
)
|
|
|
|
|
|
try:
|
|
|
- form_data, flags = await chat_completion_filter_functions_handler(
|
|
|
- request, form_data, model, extra_params
|
|
|
+ form_data, flags = await process_filter_functions(
|
|
|
+ request=request,
|
|
|
+ filter_ids=get_sorted_filter_ids(model),
|
|
|
+ filter_type="inlet",
|
|
|
+ form_data=form_data,
|
|
|
+ extra_params=extra_params,
|
|
|
)
|
|
|
except Exception as e:
|
|
|
raise Exception(f"Error: {e}")
|
|
@@ -1124,11 +1039,15 @@ async def process_chat_response(
|
|
|
|
|
|
def split_content_and_whitespace(content):
|
|
|
content_stripped = content.rstrip()
|
|
|
- original_whitespace = content[len(content_stripped):] if len(content) > len(content_stripped) else ''
|
|
|
+ original_whitespace = (
|
|
|
+ content[len(content_stripped) :]
|
|
|
+ if len(content) > len(content_stripped)
|
|
|
+ else ""
|
|
|
+ )
|
|
|
return content_stripped, original_whitespace
|
|
|
|
|
|
def is_opening_code_block(content):
|
|
|
- backtick_segments = content.split('```')
|
|
|
+ backtick_segments = content.split("```")
|
|
|
# Even number of segments means the last backticks are opening a new block
|
|
|
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
|
|
|
|
|
@@ -1198,10 +1117,15 @@ async def process_chat_response(
|
|
|
output = block.get("output", None)
|
|
|
lang = attributes.get("lang", "")
|
|
|
|
|
|
- content_stripped, original_whitespace = split_content_and_whitespace(content)
|
|
|
+ content_stripped, original_whitespace = (
|
|
|
+ split_content_and_whitespace(content)
|
|
|
+ )
|
|
|
if is_opening_code_block(content_stripped):
|
|
|
# Remove trailing backticks that would open a new block
|
|
|
- content = content_stripped.rstrip('`').rstrip() + original_whitespace
|
|
|
+ content = (
|
|
|
+ content_stripped.rstrip("`").rstrip()
|
|
|
+ + original_whitespace
|
|
|
+ )
|
|
|
else:
|
|
|
# Keep content as is - either closing backticks or no backticks
|
|
|
content = content_stripped + original_whitespace
|