Browse Source

enh: filter function priority valve support

Timothy J. Baek 10 months ago
parent
commit
8b99870189
1 changed files with 79 additions and 56 deletions
  1. 79 56
      backend/main.py

+ 79 - 56
backend/main.py

@@ -389,6 +389,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 )
             model = app.state.MODELS[model_id]
 
+            def get_priority(function_id):
+                function = Functions.get_function_by_id(function_id)
+                if function is not None and hasattr(function, "valves"):
+                    return (function.valves if function.valves else {}).get(
+                        "priority", 0
+                    )
+                return 0
+
             filter_ids = [
                 function.id
                 for function in Functions.get_functions_by_type(
@@ -400,6 +408,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 filter_ids.extend(model["info"]["meta"].get("filterIds", []))
                 filter_ids = list(set(filter_ids))
 
+            filter_ids.sort(key=get_priority)
             for filter_id in filter_ids:
                 filter = Functions.get_function_by_id(filter_id)
                 if filter:
@@ -1122,72 +1131,86 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
             else:
                 pass
 
+    def get_priority(function_id):
+        function = Functions.get_function_by_id(function_id)
+        if function is not None and hasattr(function, "valves"):
+            return (function.valves if function.valves else {}).get("priority", 0)
+        return 0
+
+    filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
     # Check if the model has any filters
     if "info" in model and "meta" in model["info"]:
-        for filter_id in model["info"]["meta"].get("filterIds", []):
-            filter = Functions.get_function_by_id(filter_id)
-            if filter:
-                if filter_id in webui_app.state.FUNCTIONS:
-                    function_module = webui_app.state.FUNCTIONS[filter_id]
-                else:
-                    function_module, function_type = load_function_module_by_id(
-                        filter_id
-                    )
-                    webui_app.state.FUNCTIONS[filter_id] = function_module
+        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+        filter_ids = list(set(filter_ids))
 
-                if hasattr(function_module, "valves") and hasattr(
-                    function_module, "Valves"
-                ):
-                    valves = Functions.get_function_valves_by_id(filter_id)
-                    function_module.valves = function_module.Valves(
-                        **(valves if valves else {})
-                    )
+    # Sort filter_ids by priority, using the get_priority function
+    filter_ids.sort(key=get_priority)
 
-                try:
-                    if hasattr(function_module, "outlet"):
-                        outlet = function_module.outlet
-
-                        # Get the signature of the function
-                        sig = inspect.signature(outlet)
-                        params = {"body": data}
-
-                        if "__user__" in sig.parameters:
-                            __user__ = {
-                                "id": user.id,
-                                "email": user.email,
-                                "name": user.name,
-                                "role": user.role,
-                            }
+    for filter_id in filter_ids:
+        filter = Functions.get_function_by_id(filter_id)
+        if filter:
+            if filter_id in webui_app.state.FUNCTIONS:
+                function_module = webui_app.state.FUNCTIONS[filter_id]
+            else:
+                function_module, function_type = load_function_module_by_id(filter_id)
+                webui_app.state.FUNCTIONS[filter_id] = function_module
 
-                            try:
-                                if hasattr(function_module, "UserValves"):
-                                    __user__["valves"] = function_module.UserValves(
-                                        **Functions.get_user_valves_by_id_and_user_id(
-                                            filter_id, user.id
-                                        )
+            if hasattr(function_module, "valves") and hasattr(
+                function_module, "Valves"
+            ):
+                valves = Functions.get_function_valves_by_id(filter_id)
+                function_module.valves = function_module.Valves(
+                    **(valves if valves else {})
+                )
+
+            try:
+                if hasattr(function_module, "outlet"):
+                    outlet = function_module.outlet
+
+                    # Get the signature of the function
+                    sig = inspect.signature(outlet)
+                    params = {"body": data}
+
+                    if "__user__" in sig.parameters:
+                        __user__ = {
+                            "id": user.id,
+                            "email": user.email,
+                            "name": user.name,
+                            "role": user.role,
+                        }
+
+                        try:
+                            if hasattr(function_module, "UserValves"):
+                                __user__["valves"] = function_module.UserValves(
+                                    **Functions.get_user_valves_by_id_and_user_id(
+                                        filter_id, user.id
                                     )
-                            except Exception as e:
-                                print(e)
+                                )
+                        except Exception as e:
+                            print(e)
 
-                            params = {**params, "__user__": __user__}
+                        params = {**params, "__user__": __user__}
 
-                        if "__id__" in sig.parameters:
-                            params = {
-                                **params,
-                                "__id__": filter_id,
-                            }
+                    if "__id__" in sig.parameters:
+                        params = {
+                            **params,
+                            "__id__": filter_id,
+                        }
 
-                        if inspect.iscoroutinefunction(outlet):
-                            data = await outlet(**params)
-                        else:
-                            data = outlet(**params)
+                    if inspect.iscoroutinefunction(outlet):
+                        data = await outlet(**params)
+                    else:
+                        data = outlet(**params)
 
-                except Exception as e:
-                    print(f"Error: {e}")
-                    return JSONResponse(
-                        status_code=status.HTTP_400_BAD_REQUEST,
-                        content={"detail": str(e)},
-                    )
+            except Exception as e:
+                print(f"Error: {e}")
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
+                )
 
     return data