Timothy J. Baek 10 місяців тому
батько
коміт
3a629ffe00
2 змінених файлів з 65 додано та 56 видалено
  1. 63 55
      backend/main.py
  2. 2 1
      src/lib/components/workspace/Functions.svelte

+ 63 - 55
backend/main.py

@@ -376,70 +376,77 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 )
             model = app.state.MODELS[model_id]
 
+            filter_ids = [
+                function.id
+                for function in Functions.get_functions_by_type(
+                    "filter", active_only=True
+                )
+            ]
             # Check if the model has any filters
             if "info" in model and "meta" in model["info"]:
-                for filter_id in model["info"]["meta"].get("filterIds", []):
-                    filter = Functions.get_function_by_id(filter_id)
-                    if filter:
-                        if filter_id in webui_app.state.FUNCTIONS:
-                            function_module = webui_app.state.FUNCTIONS[filter_id]
-                        else:
-                            function_module, function_type = load_function_module_by_id(
-                                filter_id
-                            )
-                            webui_app.state.FUNCTIONS[filter_id] = function_module
+                filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+                filter_ids = list(set(filter_ids))
+
+            for filter_id in filter_ids:
+                filter = Functions.get_function_by_id(filter_id)
+                if filter:
+                    if filter_id in webui_app.state.FUNCTIONS:
+                        function_module = webui_app.state.FUNCTIONS[filter_id]
+                    else:
+                        function_module, function_type = load_function_module_by_id(
+                            filter_id
+                        )
+                        webui_app.state.FUNCTIONS[filter_id] = function_module
 
-                        # Check if the function has a file_handler variable
-                        if hasattr(function_module, "file_handler"):
-                            skip_files = function_module.file_handler
+                    # Check if the function has a file_handler variable
+                    if hasattr(function_module, "file_handler"):
+                        skip_files = function_module.file_handler
 
-                        try:
-                            if hasattr(function_module, "inlet"):
-                                inlet = function_module.inlet
-
-                                # Get the signature of the function
-                                sig = inspect.signature(inlet)
-                                params = {"body": data}
-
-                                if "__user__" in sig.parameters:
-                                    __user__ = {
-                                        "id": user.id,
-                                        "email": user.email,
-                                        "name": user.name,
-                                        "role": user.role,
-                                    }
-
-                                    try:
-                                        if hasattr(function_module, "UserValves"):
-                                            __user__["valves"] = (
-                                                function_module.UserValves(
-                                                    **Functions.get_user_valves_by_id_and_user_id(
-                                                        filter_id, user.id
-                                                    )
-                                                )
+                    try:
+                        if hasattr(function_module, "inlet"):
+                            inlet = function_module.inlet
+
+                            # Get the signature of the function
+                            sig = inspect.signature(inlet)
+                            params = {"body": data}
+
+                            if "__user__" in sig.parameters:
+                                __user__ = {
+                                    "id": user.id,
+                                    "email": user.email,
+                                    "name": user.name,
+                                    "role": user.role,
+                                }
+
+                                try:
+                                    if hasattr(function_module, "UserValves"):
+                                        __user__["valves"] = function_module.UserValves(
+                                            **Functions.get_user_valves_by_id_and_user_id(
+                                                filter_id, user.id
                                             )
-                                    except Exception as e:
-                                        print(e)
+                                        )
+                                except Exception as e:
+                                    print(e)
 
-                                    params = {**params, "__user__": __user__}
+                                params = {**params, "__user__": __user__}
 
-                                if "__id__" in sig.parameters:
-                                    params = {
-                                        **params,
-                                        "__id__": filter_id,
-                                    }
+                            if "__id__" in sig.parameters:
+                                params = {
+                                    **params,
+                                    "__id__": filter_id,
+                                }
 
-                                if inspect.iscoroutinefunction(inlet):
-                                    data = await inlet(**params)
-                                else:
-                                    data = inlet(**params)
+                            if inspect.iscoroutinefunction(inlet):
+                                data = await inlet(**params)
+                            else:
+                                data = inlet(**params)
 
-                        except Exception as e:
-                            print(f"Error: {e}")
-                            return JSONResponse(
-                                status_code=status.HTTP_400_BAD_REQUEST,
-                                content={"detail": str(e)},
-                            )
+                    except Exception as e:
+                        print(f"Error: {e}")
+                        return JSONResponse(
+                            status_code=status.HTTP_400_BAD_REQUEST,
+                            content={"detail": str(e)},
+                        )
 
             # Set the task model
             task_model_id = data["model"]
@@ -863,6 +870,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
 
     pipe = model.get("pipe")
     if pipe:
+
         async def job():
             pipe_id = form_data["model"]
             if "." in pipe_id:

+ 2 - 1
src/lib/components/workspace/Functions.svelte

@@ -227,8 +227,9 @@
 				<div class=" self-center mx-1">
 					<Switch
 						bind:state={func.is_active}
-						on:change={(e) => {
+						on:change={async (e) => {
 							toggleFunctionById(localStorage.token, func.id);
+							models.set(await getModels(localStorage.token));
 						}}
 					/>
 				</div>