|
@@ -44,6 +44,10 @@ from open_webui.utils.response import (
|
|
|
convert_response_ollama_to_openai,
|
|
|
convert_streaming_response_ollama_to_openai,
|
|
|
)
|
|
|
+from open_webui.utils.filter import (
|
|
|
+ get_sorted_filter_ids,
|
|
|
+ process_filter_functions,
|
|
|
+)
|
|
|
|
|
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
|
|
|
|
@@ -177,116 +181,37 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
|
|
except Exception as e:
|
|
|
return Exception(f"Error: {e}")
|
|
|
|
|
|
- __event_emitter__ = get_event_emitter(
|
|
|
- {
|
|
|
- "chat_id": data["chat_id"],
|
|
|
- "message_id": data["id"],
|
|
|
- "session_id": data["session_id"],
|
|
|
- "user_id": user.id,
|
|
|
- }
|
|
|
- )
|
|
|
-
|
|
|
- __event_call__ = get_event_call(
|
|
|
- {
|
|
|
- "chat_id": data["chat_id"],
|
|
|
- "message_id": data["id"],
|
|
|
- "session_id": data["session_id"],
|
|
|
- "user_id": user.id,
|
|
|
- }
|
|
|
- )
|
|
|
-
|
|
|
- def get_priority(function_id):
|
|
|
- function = Functions.get_function_by_id(function_id)
|
|
|
- if function is not None and hasattr(function, "valves"):
|
|
|
- # TODO: Fix FunctionModel to include vavles
|
|
|
- return (function.valves if function.valves else {}).get("priority", 0)
|
|
|
- return 0
|
|
|
-
|
|
|
- filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
|
|
- if "info" in model and "meta" in model["info"]:
|
|
|
- filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
|
|
- filter_ids = list(set(filter_ids))
|
|
|
-
|
|
|
- enabled_filter_ids = [
|
|
|
- function.id
|
|
|
- for function in Functions.get_functions_by_type("filter", active_only=True)
|
|
|
- ]
|
|
|
- filter_ids = [
|
|
|
- filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
|
|
- ]
|
|
|
-
|
|
|
- # Sort filter_ids by priority, using the get_priority function
|
|
|
- filter_ids.sort(key=get_priority)
|
|
|
-
|
|
|
- for filter_id in filter_ids:
|
|
|
- filter = Functions.get_function_by_id(filter_id)
|
|
|
- if not filter:
|
|
|
- continue
|
|
|
-
|
|
|
- if filter_id in request.app.state.FUNCTIONS:
|
|
|
- function_module = request.app.state.FUNCTIONS[filter_id]
|
|
|
- else:
|
|
|
- function_module, _, _ = load_function_module_by_id(filter_id)
|
|
|
- request.app.state.FUNCTIONS[filter_id] = function_module
|
|
|
-
|
|
|
- if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
|
|
- valves = Functions.get_function_valves_by_id(filter_id)
|
|
|
- function_module.valves = function_module.Valves(
|
|
|
- **(valves if valves else {})
|
|
|
- )
|
|
|
-
|
|
|
- if not hasattr(function_module, "outlet"):
|
|
|
- continue
|
|
|
- try:
|
|
|
- outlet = function_module.outlet
|
|
|
-
|
|
|
- # Get the signature of the function
|
|
|
- sig = inspect.signature(outlet)
|
|
|
- params = {"body": data}
|
|
|
-
|
|
|
- # Extra parameters to be passed to the function
|
|
|
- extra_params = {
|
|
|
- "__model__": model,
|
|
|
- "__id__": filter_id,
|
|
|
- "__event_emitter__": __event_emitter__,
|
|
|
- "__event_call__": __event_call__,
|
|
|
- "__request__": request,
|
|
|
- }
|
|
|
+ metadata = {
|
|
|
+ "chat_id": data["chat_id"],
|
|
|
+ "message_id": data["id"],
|
|
|
+ "session_id": data["session_id"],
|
|
|
+ "user_id": user.id,
|
|
|
+ }
|
|
|
+
|
|
|
+ extra_params = {
|
|
|
+ "__event_emitter__": get_event_emitter(metadata),
|
|
|
+ "__event_call__": get_event_call(metadata),
|
|
|
+ "__user__": {
|
|
|
+ "id": user.id,
|
|
|
+ "email": user.email,
|
|
|
+ "name": user.name,
|
|
|
+ "role": user.role,
|
|
|
+ },
|
|
|
+ "__metadata__": metadata,
|
|
|
+ "__request__": request,
|
|
|
+ }
|
|
|
|
|
|
- # Add extra params in contained in function signature
|
|
|
- for key, value in extra_params.items():
|
|
|
- if key in sig.parameters:
|
|
|
- params[key] = value
|
|
|
-
|
|
|
- if "__user__" in sig.parameters:
|
|
|
- __user__ = {
|
|
|
- "id": user.id,
|
|
|
- "email": user.email,
|
|
|
- "name": user.name,
|
|
|
- "role": user.role,
|
|
|
- }
|
|
|
-
|
|
|
- try:
|
|
|
- if hasattr(function_module, "UserValves"):
|
|
|
- __user__["valves"] = function_module.UserValves(
|
|
|
- **Functions.get_user_valves_by_id_and_user_id(
|
|
|
- filter_id, user.id
|
|
|
- )
|
|
|
- )
|
|
|
- except Exception as e:
|
|
|
- print(e)
|
|
|
-
|
|
|
- params = {**params, "__user__": __user__}
|
|
|
-
|
|
|
- if inspect.iscoroutinefunction(outlet):
|
|
|
- data = await outlet(**params)
|
|
|
- else:
|
|
|
- data = outlet(**params)
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- return Exception(f"Error: {e}")
|
|
|
-
|
|
|
- return data
|
|
|
+ try:
|
|
|
+ result, _ = await process_filter_functions(
|
|
|
+ handler_type="outlet",
|
|
|
+ filter_ids=get_sorted_filter_ids(model),
|
|
|
+ request=request,
|
|
|
+ data=data,
|
|
|
+ extra_params=extra_params,
|
|
|
+ )
|
|
|
+ return result
|
|
|
+ except Exception as e:
|
|
|
+ return Exception(f"Error: {e}")
|
|
|
|
|
|
|
|
|
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
|